Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
9c4763ec
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
9c4763ec
编写于
7月 05, 2022
作者:
小湉湉
提交者:
GitHub
7月 05, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2113 from yt605155624/rm_server_log
[server]log redundancy in server
上级
e4a8e153
4b1f82d3
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
226 addition
and
223 deletion
+226
-223
paddlespeech/cli/tts/infer.py
paddlespeech/cli/tts/infer.py
+1
-1
paddlespeech/server/bin/paddlespeech_client.py
paddlespeech/server/bin/paddlespeech_client.py
+0
-2
paddlespeech/server/engine/acs/python/acs_engine.py
paddlespeech/server/engine/acs/python/acs_engine.py
+11
-9
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
+22
-22
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py
...ch/server/engine/asr/online/paddleinference/asr_engine.py
+25
-23
paddlespeech/server/engine/asr/online/python/asr_engine.py
paddlespeech/server/engine/asr/online/python/asr_engine.py
+44
-42
paddlespeech/server/engine/asr/paddleinference/asr_engine.py
paddlespeech/server/engine/asr/paddleinference/asr_engine.py
+7
-7
paddlespeech/server/engine/asr/python/asr_engine.py
paddlespeech/server/engine/asr/python/asr_engine.py
+2
-2
paddlespeech/server/engine/cls/paddleinference/cls_engine.py
paddlespeech/server/engine/cls/paddleinference/cls_engine.py
+11
-10
paddlespeech/server/engine/cls/python/cls_engine.py
paddlespeech/server/engine/cls/python/cls_engine.py
+2
-2
paddlespeech/server/engine/engine_warmup.py
paddlespeech/server/engine/engine_warmup.py
+3
-3
paddlespeech/server/engine/text/python/text_engine.py
paddlespeech/server/engine/text/python/text_engine.py
+6
-5
paddlespeech/server/engine/tts/online/onnx/tts_engine.py
paddlespeech/server/engine/tts/online/onnx/tts_engine.py
+12
-14
paddlespeech/server/engine/tts/online/python/tts_engine.py
paddlespeech/server/engine/tts/online/python/tts_engine.py
+9
-15
paddlespeech/server/engine/tts/paddleinference/tts_engine.py
paddlespeech/server/engine/tts/paddleinference/tts_engine.py
+29
-28
paddlespeech/server/engine/tts/python/tts_engine.py
paddlespeech/server/engine/tts/python/tts_engine.py
+13
-13
paddlespeech/server/engine/vector/python/vector_engine.py
paddlespeech/server/engine/vector/python/vector_engine.py
+17
-14
paddlespeech/server/utils/audio_handler.py
paddlespeech/server/utils/audio_handler.py
+5
-6
paddlespeech/server/utils/audio_process.py
paddlespeech/server/utils/audio_process.py
+1
-1
paddlespeech/server/utils/onnx_infer.py
paddlespeech/server/utils/onnx_infer.py
+2
-2
paddlespeech/server/utils/util.py
paddlespeech/server/utils/util.py
+4
-2
未找到文件。
paddlespeech/cli/tts/infer.py
浏览文件 @
9c4763ec
...
...
@@ -382,7 +382,7 @@ class TTSExecutor(BaseExecutor):
text
,
merge_sentences
=
merge_sentences
)
phone_ids
=
input_ids
[
"phone_ids"
]
else
:
print
(
"lang should in {'zh', 'en'}!"
)
logger
.
error
(
"lang should in {'zh', 'en'}!"
)
self
.
frontend_time
=
time
.
time
()
-
frontend_st
self
.
am_time
=
0
...
...
paddlespeech/server/bin/paddlespeech_client.py
浏览文件 @
9c4763ec
...
...
@@ -123,7 +123,6 @@ class TTSClientExecutor(BaseExecutor):
time_end
=
time
.
time
()
time_consume
=
time_end
-
time_start
response_dict
=
res
.
json
()
logger
.
info
(
response_dict
[
"message"
])
logger
.
info
(
"Save synthesized audio successfully on %s."
%
(
output
))
logger
.
info
(
"Audio duration: %f s."
%
(
response_dict
[
'result'
][
'duration'
]))
...
...
@@ -702,7 +701,6 @@ class VectorClientExecutor(BaseExecutor):
test_audio
=
args
.
test
,
task
=
task
)
time_end
=
time
.
time
()
logger
.
info
(
f
"The vector:
{
res
}
"
)
logger
.
info
(
"Response time %f s."
%
(
time_end
-
time_start
))
return
True
except
Exception
as
e
:
...
...
paddlespeech/server/engine/acs/python/acs_engine.py
浏览文件 @
9c4763ec
...
...
@@ -30,7 +30,7 @@ class ACSEngine(BaseEngine):
"""The ACSEngine Engine
"""
super
(
ACSEngine
,
self
).
__init__
()
logger
.
info
(
"Create the ACSEngine Instance"
)
logger
.
debug
(
"Create the ACSEngine Instance"
)
self
.
word_list
=
[]
def
init
(
self
,
config
:
dict
):
...
...
@@ -42,7 +42,7 @@ class ACSEngine(BaseEngine):
Returns:
bool: The engine instance flag
"""
logger
.
info
(
"Init the acs engine"
)
logger
.
debug
(
"Init the acs engine"
)
try
:
self
.
config
=
config
self
.
device
=
self
.
config
.
get
(
"device"
,
paddle
.
get_device
())
...
...
@@ -50,7 +50,7 @@ class ACSEngine(BaseEngine):
# websocket default ping timeout is 20 seconds
self
.
ping_timeout
=
self
.
config
.
get
(
"ping_timeout"
,
20
)
paddle
.
set_device
(
self
.
device
)
logger
.
info
(
f
"ACS Engine set the device:
{
self
.
device
}
"
)
logger
.
debug
(
f
"ACS Engine set the device:
{
self
.
device
}
"
)
except
BaseException
as
e
:
logger
.
error
(
...
...
@@ -66,7 +66,9 @@ class ACSEngine(BaseEngine):
self
.
url
=
"ws://"
+
self
.
config
.
asr_server_ip
+
":"
+
str
(
self
.
config
.
asr_server_port
)
+
"/paddlespeech/asr/streaming"
logger
.
info
(
"Init the acs engine successfully"
)
logger
.
info
(
"Initialize acs server engine successfully on device: %s."
%
(
self
.
device
))
return
True
def
read_search_words
(
self
):
...
...
@@ -95,12 +97,12 @@ class ACSEngine(BaseEngine):
Returns:
_type_: _description_
"""
logger
.
info
(
"send a message to the server"
)
logger
.
debug
(
"send a message to the server"
)
if
self
.
url
is
None
:
logger
.
error
(
"No asr server, please input valid ip and port"
)
return
""
ws
=
websocket
.
WebSocket
()
logger
.
info
(
f
"set the ping timeout:
{
self
.
ping_timeout
}
seconds"
)
logger
.
debug
(
f
"set the ping timeout:
{
self
.
ping_timeout
}
seconds"
)
ws
.
connect
(
self
.
url
,
ping_timeout
=
self
.
ping_timeout
)
audio_info
=
json
.
dumps
(
{
...
...
@@ -123,7 +125,7 @@ class ACSEngine(BaseEngine):
logger
.
info
(
f
"audio result:
{
msg
}
"
)
# 3. send chunk audio data to engine
logger
.
info
(
"send the end signal"
)
logger
.
debug
(
"send the end signal"
)
audio_info
=
json
.
dumps
(
{
"name"
:
"test.wav"
,
...
...
@@ -197,7 +199,7 @@ class ACSEngine(BaseEngine):
start
=
max
(
time_stamp
[
m
.
start
(
0
)][
'bg'
]
-
offset
,
0
)
end
=
min
(
time_stamp
[
m
.
end
(
0
)
-
1
][
'ed'
]
+
offset
,
max_ed
)
logger
.
info
(
f
'start:
{
start
}
, end:
{
end
}
'
)
logger
.
debug
(
f
'start:
{
start
}
, end:
{
end
}
'
)
acs_result
.
append
({
'w'
:
w
,
'bg'
:
start
,
'ed'
:
end
})
return
acs_result
,
asr_result
...
...
@@ -212,7 +214,7 @@ class ACSEngine(BaseEngine):
Returns:
acs_result, asr_result: the acs result and the asr result
"""
logger
.
info
(
"start to process the audio content search"
)
logger
.
debug
(
"start to process the audio content search"
)
msg
=
self
.
get_asr_content
(
io
.
BytesIO
(
audio_data
))
acs_result
,
asr_result
=
self
.
get_macthed_word
(
msg
)
...
...
paddlespeech/server/engine/asr/online/onnx/asr_engine.py
浏览文件 @
9c4763ec
...
...
@@ -44,7 +44,7 @@ class PaddleASRConnectionHanddler:
asr_engine (ASREngine): the global asr engine
"""
super
().
__init__
()
logger
.
info
(
logger
.
debug
(
"create an paddle asr connection handler to process the websocket connection"
)
self
.
config
=
asr_engine
.
config
# server config
...
...
@@ -152,12 +152,12 @@ class PaddleASRConnectionHanddler:
self
.
output_reset
()
def
extract_feat
(
self
,
samples
:
ByteString
):
logger
.
info
(
"Online ASR extract the feat"
)
logger
.
debug
(
"Online ASR extract the feat"
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
self
.
num_samples
+=
samples
.
shape
[
0
]
logger
.
info
(
logger
.
debug
(
f
"This package receive
{
samples
.
shape
[
0
]
}
pcm data. Global samples:
{
self
.
num_samples
}
"
)
...
...
@@ -168,7 +168,7 @@ class PaddleASRConnectionHanddler:
else
:
assert
self
.
remained_wav
.
ndim
==
1
# (T,)
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
logger
.
debug
(
f
"The concatenation of remain and now audio samples length is:
{
self
.
remained_wav
.
shape
}
"
)
...
...
@@ -202,14 +202,14 @@ class PaddleASRConnectionHanddler:
# update remained wav
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
logger
.
info
(
logger
.
debug
(
f
"process the audio feature success, the cached feat shape:
{
self
.
cached_feat
.
shape
}
"
)
logger
.
info
(
logger
.
debug
(
f
"After extract feat, the cached remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
logger
.
info
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
info
(
f
"global frames:
{
self
.
num_frames
}
"
)
logger
.
debug
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
debug
(
f
"global frames:
{
self
.
num_frames
}
"
)
def
decode
(
self
,
is_finished
=
False
):
"""advance decoding
...
...
@@ -237,7 +237,7 @@ class PaddleASRConnectionHanddler:
return
num_frames
=
self
.
cached_feat
.
shape
[
1
]
logger
.
info
(
logger
.
debug
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
)
...
...
@@ -355,7 +355,7 @@ class ASRServerExecutor(ASRExecutor):
lm_url
=
self
.
task_resource
.
res_dict
[
'lm_url'
]
lm_md5
=
self
.
task_resource
.
res_dict
[
'lm_md5'
]
logger
.
info
(
f
"Start to load language model
{
lm_url
}
"
)
logger
.
debug
(
f
"Start to load language model
{
lm_url
}
"
)
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
...
...
@@ -367,7 +367,7 @@ class ASRServerExecutor(ASRExecutor):
if
"deepspeech2"
in
self
.
model_type
:
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
logger
.
debug
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor
=
onnx_infer
.
get_sess
(
model_path
=
self
.
am_model
,
sess_conf
=
self
.
am_predictor_conf
)
else
:
...
...
@@ -400,7 +400,7 @@ class ASRServerExecutor(ASRExecutor):
self
.
num_decoding_left_chunks
=
num_decoding_left_chunks
# conf for paddleinference predictor or onnx
self
.
am_predictor_conf
=
am_predictor_conf
logger
.
info
(
f
"model_type:
{
self
.
model_type
}
"
)
logger
.
debug
(
f
"model_type:
{
self
.
model_type
}
"
)
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
...
...
@@ -422,12 +422,11 @@ class ASRServerExecutor(ASRExecutor):
# self.res_path, self.task_resource.res_dict[
# 'params']) if am_params is None else os.path.abspath(am_params)
logger
.
info
(
"Load the pretrained model:"
)
logger
.
info
(
f
" tag =
{
tag
}
"
)
logger
.
info
(
f
" res_path:
{
self
.
res_path
}
"
)
logger
.
info
(
f
" cfg path:
{
self
.
cfg_path
}
"
)
logger
.
info
(
f
" am_model path:
{
self
.
am_model
}
"
)
# logger.info(f" am_params path: {self.am_params}")
logger
.
debug
(
"Load the pretrained model:"
)
logger
.
debug
(
f
" tag =
{
tag
}
"
)
logger
.
debug
(
f
" res_path:
{
self
.
res_path
}
"
)
logger
.
debug
(
f
" cfg path:
{
self
.
cfg_path
}
"
)
logger
.
debug
(
f
" am_model path:
{
self
.
am_model
}
"
)
#Init body.
self
.
config
=
CfgNode
(
new_allowed
=
True
)
...
...
@@ -436,7 +435,7 @@ class ASRServerExecutor(ASRExecutor):
if
self
.
config
.
spm_model_prefix
:
self
.
config
.
spm_model_prefix
=
os
.
path
.
join
(
self
.
res_path
,
self
.
config
.
spm_model_prefix
)
logger
.
info
(
f
"spm model path:
{
self
.
config
.
spm_model_prefix
}
"
)
logger
.
debug
(
f
"spm model path:
{
self
.
config
.
spm_model_prefix
}
"
)
self
.
vocab
=
self
.
config
.
vocab_filepath
...
...
@@ -450,7 +449,7 @@ class ASRServerExecutor(ASRExecutor):
# AM predictor
self
.
init_model
()
logger
.
info
(
f
"create the
{
model_type
}
model success"
)
logger
.
debug
(
f
"create the
{
model_type
}
model success"
)
return
True
...
...
@@ -501,7 +500,7 @@ class ASREngine(BaseEngine):
"If all GPU or XPU is used, you can set the server to 'cpu'"
)
sys
.
exit
(
-
1
)
logger
.
info
(
f
"paddlespeech_server set the device:
{
self
.
device
}
"
)
logger
.
debug
(
f
"paddlespeech_server set the device:
{
self
.
device
}
"
)
if
not
self
.
init_model
():
logger
.
error
(
...
...
@@ -509,7 +508,8 @@ class ASREngine(BaseEngine):
)
return
False
logger
.
info
(
"Initialize ASR server engine successfully."
)
logger
.
info
(
"Initialize ASR server engine successfully on device: %s."
%
(
self
.
device
))
return
True
def
new_handler
(
self
):
...
...
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py
浏览文件 @
9c4763ec
...
...
@@ -44,7 +44,7 @@ class PaddleASRConnectionHanddler:
asr_engine (ASREngine): the global asr engine
"""
super
().
__init__
()
logger
.
info
(
logger
.
debug
(
"create an paddle asr connection handler to process the websocket connection"
)
self
.
config
=
asr_engine
.
config
# server config
...
...
@@ -157,7 +157,7 @@ class PaddleASRConnectionHanddler:
assert
samples
.
ndim
==
1
self
.
num_samples
+=
samples
.
shape
[
0
]
logger
.
info
(
logger
.
debug
(
f
"This package receive
{
samples
.
shape
[
0
]
}
pcm data. Global samples:
{
self
.
num_samples
}
"
)
...
...
@@ -168,7 +168,7 @@ class PaddleASRConnectionHanddler:
else
:
assert
self
.
remained_wav
.
ndim
==
1
# (T,)
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
logger
.
debug
(
f
"The concatenation of remain and now audio samples length is:
{
self
.
remained_wav
.
shape
}
"
)
...
...
@@ -202,14 +202,14 @@ class PaddleASRConnectionHanddler:
# update remained wav
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
logger
.
info
(
logger
.
debug
(
f
"process the audio feature success, the cached feat shape:
{
self
.
cached_feat
.
shape
}
"
)
logger
.
info
(
logger
.
debug
(
f
"After extract feat, the cached remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
logger
.
info
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
info
(
f
"global frames:
{
self
.
num_frames
}
"
)
logger
.
debug
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
debug
(
f
"global frames:
{
self
.
num_frames
}
"
)
def
decode
(
self
,
is_finished
=
False
):
"""advance decoding
...
...
@@ -237,13 +237,13 @@ class PaddleASRConnectionHanddler:
return
num_frames
=
self
.
cached_feat
.
shape
[
1
]
logger
.
info
(
logger
.
debug
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
)
# the cached feat must be larger decoding_window
if
num_frames
<
decoding_window
and
not
is_finished
:
logger
.
info
(
logger
.
debug
(
f
"frame feat num is less than
{
decoding_window
}
, please input more pcm data"
)
return
None
,
None
...
...
@@ -294,7 +294,7 @@ class PaddleASRConnectionHanddler:
Returns:
logprob: poster probability.
"""
logger
.
info
(
"start to decoce one chunk for deepspeech2"
)
logger
.
debug
(
"start to decoce one chunk for deepspeech2"
)
input_names
=
self
.
am_predictor
.
get_input_names
()
audio_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
0
])
audio_len_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
1
])
...
...
@@ -369,7 +369,7 @@ class ASRServerExecutor(ASRExecutor):
lm_url
=
self
.
task_resource
.
res_dict
[
'lm_url'
]
lm_md5
=
self
.
task_resource
.
res_dict
[
'lm_md5'
]
logger
.
info
(
f
"Start to load language model
{
lm_url
}
"
)
logger
.
debug
(
f
"Start to load language model
{
lm_url
}
"
)
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
...
...
@@ -381,7 +381,7 @@ class ASRServerExecutor(ASRExecutor):
if
"deepspeech2"
in
self
.
model_type
:
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
logger
.
debug
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor
=
init_predictor
(
model_file
=
self
.
am_model
,
params_file
=
self
.
am_params
,
...
...
@@ -415,7 +415,7 @@ class ASRServerExecutor(ASRExecutor):
self
.
num_decoding_left_chunks
=
num_decoding_left_chunks
# conf for paddleinference predictor or onnx
self
.
am_predictor_conf
=
am_predictor_conf
logger
.
info
(
f
"model_type:
{
self
.
model_type
}
"
)
logger
.
debug
(
f
"model_type:
{
self
.
model_type
}
"
)
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
...
...
@@ -437,12 +437,12 @@ class ASRServerExecutor(ASRExecutor):
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
logger
.
info
(
"Load the pretrained model:"
)
logger
.
info
(
f
" tag =
{
tag
}
"
)
logger
.
info
(
f
" res_path:
{
self
.
res_path
}
"
)
logger
.
info
(
f
" cfg path:
{
self
.
cfg_path
}
"
)
logger
.
info
(
f
" am_model path:
{
self
.
am_model
}
"
)
logger
.
info
(
f
" am_params path:
{
self
.
am_params
}
"
)
logger
.
debug
(
"Load the pretrained model:"
)
logger
.
debug
(
f
" tag =
{
tag
}
"
)
logger
.
debug
(
f
" res_path:
{
self
.
res_path
}
"
)
logger
.
debug
(
f
" cfg path:
{
self
.
cfg_path
}
"
)
logger
.
debug
(
f
" am_model path:
{
self
.
am_model
}
"
)
logger
.
debug
(
f
" am_params path:
{
self
.
am_params
}
"
)
#Init body.
self
.
config
=
CfgNode
(
new_allowed
=
True
)
...
...
@@ -451,7 +451,7 @@ class ASRServerExecutor(ASRExecutor):
if
self
.
config
.
spm_model_prefix
:
self
.
config
.
spm_model_prefix
=
os
.
path
.
join
(
self
.
res_path
,
self
.
config
.
spm_model_prefix
)
logger
.
info
(
f
"spm model path:
{
self
.
config
.
spm_model_prefix
}
"
)
logger
.
debug
(
f
"spm model path:
{
self
.
config
.
spm_model_prefix
}
"
)
self
.
vocab
=
self
.
config
.
vocab_filepath
...
...
@@ -465,7 +465,7 @@ class ASRServerExecutor(ASRExecutor):
# AM predictor
self
.
init_model
()
logger
.
info
(
f
"create the
{
model_type
}
model success"
)
logger
.
debug
(
f
"create the
{
model_type
}
model success"
)
return
True
...
...
@@ -516,7 +516,7 @@ class ASREngine(BaseEngine):
"If all GPU or XPU is used, you can set the server to 'cpu'"
)
sys
.
exit
(
-
1
)
logger
.
info
(
f
"paddlespeech_server set the device:
{
self
.
device
}
"
)
logger
.
debug
(
f
"paddlespeech_server set the device:
{
self
.
device
}
"
)
if
not
self
.
init_model
():
logger
.
error
(
...
...
@@ -524,7 +524,9 @@ class ASREngine(BaseEngine):
)
return
False
logger
.
info
(
"Initialize ASR server engine successfully."
)
logger
.
info
(
"Initialize ASR server engine successfully on device: %s."
%
(
self
.
device
))
return
True
def
new_handler
(
self
):
...
...
paddlespeech/server/engine/asr/online/python/asr_engine.py
浏览文件 @
9c4763ec
...
...
@@ -49,7 +49,7 @@ class PaddleASRConnectionHanddler:
asr_engine (ASREngine): the global asr engine
"""
super
().
__init__
()
logger
.
info
(
logger
.
debug
(
"create an paddle asr connection handler to process the websocket connection"
)
self
.
config
=
asr_engine
.
config
# server config
...
...
@@ -107,7 +107,7 @@ class PaddleASRConnectionHanddler:
# acoustic model
self
.
model
=
self
.
asr_engine
.
executor
.
model
self
.
continuous_decoding
=
self
.
config
.
continuous_decoding
logger
.
info
(
f
"continue decoding:
{
self
.
continuous_decoding
}
"
)
logger
.
debug
(
f
"continue decoding:
{
self
.
continuous_decoding
}
"
)
# ctc decoding config
self
.
ctc_decode_config
=
self
.
asr_engine
.
executor
.
config
.
decode
...
...
@@ -207,7 +207,7 @@ class PaddleASRConnectionHanddler:
assert
samples
.
ndim
==
1
self
.
num_samples
+=
samples
.
shape
[
0
]
logger
.
info
(
logger
.
debug
(
f
"This package receive
{
samples
.
shape
[
0
]
}
pcm data. Global samples:
{
self
.
num_samples
}
"
)
...
...
@@ -218,7 +218,7 @@ class PaddleASRConnectionHanddler:
else
:
assert
self
.
remained_wav
.
ndim
==
1
# (T,)
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
logger
.
debug
(
f
"The concatenation of remain and now audio samples length is:
{
self
.
remained_wav
.
shape
}
"
)
...
...
@@ -252,14 +252,14 @@ class PaddleASRConnectionHanddler:
# update remained wav
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
logger
.
info
(
logger
.
debug
(
f
"process the audio feature success, the cached feat shape:
{
self
.
cached_feat
.
shape
}
"
)
logger
.
info
(
logger
.
debug
(
f
"After extract feat, the cached remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
logger
.
info
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
info
(
f
"global frames:
{
self
.
num_frames
}
"
)
logger
.
debug
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
debug
(
f
"global frames:
{
self
.
num_frames
}
"
)
def
decode
(
self
,
is_finished
=
False
):
"""advance decoding
...
...
@@ -283,24 +283,24 @@ class PaddleASRConnectionHanddler:
stride
=
subsampling
*
decoding_chunk_size
if
self
.
cached_feat
is
None
:
logger
.
info
(
"no audio feat, please input more pcm data"
)
logger
.
debug
(
"no audio feat, please input more pcm data"
)
return
num_frames
=
self
.
cached_feat
.
shape
[
1
]
logger
.
info
(
logger
.
debug
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
)
# the cached feat must be larger decoding_window
if
num_frames
<
decoding_window
and
not
is_finished
:
logger
.
info
(
logger
.
debug
(
f
"frame feat num is less than
{
decoding_window
}
, please input more pcm data"
)
return
None
,
None
# if is_finished=True, we need at least context frames
if
num_frames
<
context
:
logger
.
info
(
logger
.
debug
(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return
None
,
None
...
...
@@ -354,7 +354,7 @@ class PaddleASRConnectionHanddler:
Returns:
logprob: poster probability.
"""
logger
.
info
(
"start to decoce one chunk for deepspeech2"
)
logger
.
debug
(
"start to decoce one chunk for deepspeech2"
)
input_names
=
self
.
am_predictor
.
get_input_names
()
audio_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
0
])
audio_len_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
1
])
...
...
@@ -391,7 +391,7 @@ class PaddleASRConnectionHanddler:
self
.
decoder
.
next
(
output_chunk_probs
,
output_chunk_lens
)
trans_best
,
trans_beam
=
self
.
decoder
.
decode
()
logger
.
info
(
f
"decode one best result for deepspeech2:
{
trans_best
[
0
]
}
"
)
logger
.
debug
(
f
"decode one best result for deepspeech2:
{
trans_best
[
0
]
}
"
)
return
trans_best
[
0
]
@
paddle
.
no_grad
()
...
...
@@ -402,7 +402,7 @@ class PaddleASRConnectionHanddler:
# reset endpiont state
self
.
endpoint_state
=
False
logger
.
info
(
logger
.
debug
(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
cfg
=
self
.
ctc_decode_config
...
...
@@ -427,25 +427,25 @@ class PaddleASRConnectionHanddler:
stride
=
subsampling
*
decoding_chunk_size
if
self
.
cached_feat
is
None
:
logger
.
info
(
"no audio feat, please input more pcm data"
)
logger
.
debug
(
"no audio feat, please input more pcm data"
)
return
# (B=1,T,D)
num_frames
=
self
.
cached_feat
.
shape
[
1
]
logger
.
info
(
logger
.
debug
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
)
# the cached feat must be larger decoding_window
if
num_frames
<
decoding_window
and
not
is_finished
:
logger
.
info
(
logger
.
debug
(
f
"frame feat num is less than
{
decoding_window
}
, please input more pcm data"
)
return
None
,
None
# if is_finished=True, we need at least context frames
if
num_frames
<
context
:
logger
.
info
(
logger
.
debug
(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return
None
,
None
...
...
@@ -489,7 +489,7 @@ class PaddleASRConnectionHanddler:
self
.
encoder_out
=
ys
else
:
self
.
encoder_out
=
paddle
.
concat
([
self
.
encoder_out
,
ys
],
axis
=
1
)
logger
.
info
(
logger
.
debug
(
f
"This connection handler encoder out shape:
{
self
.
encoder_out
.
shape
}
"
)
...
...
@@ -513,7 +513,8 @@ class PaddleASRConnectionHanddler:
if
self
.
endpointer
.
endpoint_detected
(
ctc_probs
.
numpy
(),
decoding_something
):
self
.
endpoint_state
=
True
logger
.
info
(
f
"Endpoint is detected at
{
self
.
num_frames
}
frame."
)
logger
.
debug
(
f
"Endpoint is detected at
{
self
.
num_frames
}
frame."
)
# advance cache of feat
assert
self
.
cached_feat
.
shape
[
0
]
==
1
#(B=1,T,D)
...
...
@@ -526,7 +527,7 @@ class PaddleASRConnectionHanddler:
def
update_result
(
self
):
"""Conformer/Transformer hyps to result.
"""
logger
.
info
(
"update the final result"
)
logger
.
debug
(
"update the final result"
)
hyps
=
self
.
hyps
# output results and tokenids
...
...
@@ -560,16 +561,16 @@ class PaddleASRConnectionHanddler:
only for conformer and transformer model.
"""
if
"deepspeech2"
in
self
.
model_type
:
logger
.
info
(
"deepspeech2 not support rescoring decoding."
)
logger
.
debug
(
"deepspeech2 not support rescoring decoding."
)
return
if
"attention_rescoring"
!=
self
.
ctc_decode_config
.
decoding_method
:
logger
.
info
(
logger
.
debug
(
f
"decoding method not match:
{
self
.
ctc_decode_config
.
decoding_method
}
, need attention_rescoring"
)
return
logger
.
info
(
"rescoring the final result"
)
logger
.
debug
(
"rescoring the final result"
)
# last decoding for last audio
self
.
searcher
.
finalize_search
()
...
...
@@ -685,7 +686,6 @@ class PaddleASRConnectionHanddler:
"bg"
:
global_offset_in_sec
+
start
,
"ed"
:
global_offset_in_sec
+
end
})
# logger.info(f"{word_time_stamp[-1]}")
self
.
word_time_stamp
=
word_time_stamp
logger
.
info
(
f
"word time stamp:
{
self
.
word_time_stamp
}
"
)
...
...
@@ -707,13 +707,13 @@ class ASRServerExecutor(ASRExecutor):
lm_url
=
self
.
task_resource
.
res_dict
[
'lm_url'
]
lm_md5
=
self
.
task_resource
.
res_dict
[
'lm_md5'
]
logger
.
info
(
f
"Start to load language model
{
lm_url
}
"
)
logger
.
debug
(
f
"Start to load language model
{
lm_url
}
"
)
self
.
download_lm
(
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
with
UpdateConfig
(
self
.
config
):
logger
.
info
(
"start to create the stream conformer asr engine"
)
logger
.
debug
(
"start to create the stream conformer asr engine"
)
# update the decoding method
if
self
.
decode_method
:
self
.
config
.
decode
.
decoding_method
=
self
.
decode_method
...
...
@@ -726,7 +726,7 @@ class ASRServerExecutor(ASRExecutor):
if
self
.
config
.
decode
.
decoding_method
not
in
[
"ctc_prefix_beam_search"
,
"attention_rescoring"
]:
logger
.
info
(
logger
.
debug
(
"we set the decoding_method to attention_rescoring"
)
self
.
config
.
decode
.
decoding_method
=
"attention_rescoring"
...
...
@@ -739,7 +739,7 @@ class ASRServerExecutor(ASRExecutor):
def
init_model
(
self
)
->
None
:
if
"deepspeech2"
in
self
.
model_type
:
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
logger
.
debug
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor
=
init_predictor
(
model_file
=
self
.
am_model
,
params_file
=
self
.
am_params
,
...
...
@@ -748,7 +748,7 @@ class ASRServerExecutor(ASRExecutor):
# load model
# model_type: {model_name}_{dataset}
model_name
=
self
.
model_type
[:
self
.
model_type
.
rindex
(
'_'
)]
logger
.
info
(
f
"model name:
{
model_name
}
"
)
logger
.
debug
(
f
"model name:
{
model_name
}
"
)
model_class
=
self
.
task_resource
.
get_model_class
(
model_name
)
model
=
model_class
.
from_config
(
self
.
config
)
self
.
model
=
model
...
...
@@ -782,7 +782,7 @@ class ASRServerExecutor(ASRExecutor):
self
.
num_decoding_left_chunks
=
num_decoding_left_chunks
# conf for paddleinference predictor or onnx
self
.
am_predictor_conf
=
am_predictor_conf
logger
.
info
(
f
"model_type:
{
self
.
model_type
}
"
)
logger
.
debug
(
f
"model_type:
{
self
.
model_type
}
"
)
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
...
...
@@ -804,12 +804,12 @@ class ASRServerExecutor(ASRExecutor):
self
.
res_path
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
cfg_path
)))
logger
.
info
(
"Load the pretrained model:"
)
logger
.
info
(
f
" tag =
{
tag
}
"
)
logger
.
info
(
f
" res_path:
{
self
.
res_path
}
"
)
logger
.
info
(
f
" cfg path:
{
self
.
cfg_path
}
"
)
logger
.
info
(
f
" am_model path:
{
self
.
am_model
}
"
)
logger
.
info
(
f
" am_params path:
{
self
.
am_params
}
"
)
logger
.
debug
(
"Load the pretrained model:"
)
logger
.
debug
(
f
" tag =
{
tag
}
"
)
logger
.
debug
(
f
" res_path:
{
self
.
res_path
}
"
)
logger
.
debug
(
f
" cfg path:
{
self
.
cfg_path
}
"
)
logger
.
debug
(
f
" am_model path:
{
self
.
am_model
}
"
)
logger
.
debug
(
f
" am_params path:
{
self
.
am_params
}
"
)
#Init body.
self
.
config
=
CfgNode
(
new_allowed
=
True
)
...
...
@@ -818,7 +818,7 @@ class ASRServerExecutor(ASRExecutor):
if
self
.
config
.
spm_model_prefix
:
self
.
config
.
spm_model_prefix
=
os
.
path
.
join
(
self
.
res_path
,
self
.
config
.
spm_model_prefix
)
logger
.
info
(
f
"spm model path:
{
self
.
config
.
spm_model_prefix
}
"
)
logger
.
debug
(
f
"spm model path:
{
self
.
config
.
spm_model_prefix
}
"
)
self
.
vocab
=
self
.
config
.
vocab_filepath
...
...
@@ -832,7 +832,7 @@ class ASRServerExecutor(ASRExecutor):
# AM predictor
self
.
init_model
()
logger
.
info
(
f
"create the
{
model_type
}
model success"
)
logger
.
debug
(
f
"create the
{
model_type
}
model success"
)
return
True
...
...
@@ -883,7 +883,7 @@ class ASREngine(BaseEngine):
"If all GPU or XPU is used, you can set the server to 'cpu'"
)
sys
.
exit
(
-
1
)
logger
.
info
(
f
"paddlespeech_server set the device:
{
self
.
device
}
"
)
logger
.
debug
(
f
"paddlespeech_server set the device:
{
self
.
device
}
"
)
if
not
self
.
init_model
():
logger
.
error
(
...
...
@@ -891,7 +891,9 @@ class ASREngine(BaseEngine):
)
return
False
logger
.
info
(
"Initialize ASR server engine successfully."
)
logger
.
info
(
"Initialize ASR server engine successfully on device: %s."
%
(
self
.
device
))
return
True
def
new_handler
(
self
):
...
...
paddlespeech/server/engine/asr/paddleinference/asr_engine.py
浏览文件 @
9c4763ec
...
...
@@ -65,10 +65,10 @@ class ASRServerExecutor(ASRExecutor):
self
.
task_resource
.
res_dict
[
'model'
])
self
.
am_params
=
os
.
path
.
join
(
self
.
res_path
,
self
.
task_resource
.
res_dict
[
'params'
])
logger
.
info
(
self
.
res_path
)
logger
.
info
(
self
.
cfg_path
)
logger
.
info
(
self
.
am_model
)
logger
.
info
(
self
.
am_params
)
logger
.
debug
(
self
.
res_path
)
logger
.
debug
(
self
.
cfg_path
)
logger
.
debug
(
self
.
am_model
)
logger
.
debug
(
self
.
am_params
)
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
am_model
=
os
.
path
.
abspath
(
am_model
)
...
...
@@ -236,16 +236,16 @@ class PaddleASRConnectionHandler(ASRServerExecutor):
if
self
.
_check
(
io
.
BytesIO
(
audio_data
),
self
.
asr_engine
.
config
.
sample_rate
,
self
.
asr_engine
.
config
.
force_yes
):
logger
.
info
(
"start running asr engine"
)
logger
.
debug
(
"start running asr engine"
)
self
.
preprocess
(
self
.
asr_engine
.
config
.
model_type
,
io
.
BytesIO
(
audio_data
))
st
=
time
.
time
()
self
.
infer
(
self
.
asr_engine
.
config
.
model_type
)
infer_time
=
time
.
time
()
-
st
self
.
output
=
self
.
postprocess
()
# Retrieve result of asr.
logger
.
info
(
"end inferring asr engine"
)
logger
.
debug
(
"end inferring asr engine"
)
else
:
logger
.
info
(
"file check failed!"
)
logger
.
error
(
"file check failed!"
)
self
.
output
=
None
logger
.
info
(
"inference time: {}"
.
format
(
infer_time
))
...
...
paddlespeech/server/engine/asr/python/asr_engine.py
浏览文件 @
9c4763ec
...
...
@@ -104,7 +104,7 @@ class PaddleASRConnectionHandler(ASRServerExecutor):
if
self
.
_check
(
io
.
BytesIO
(
audio_data
),
self
.
asr_engine
.
config
.
sample_rate
,
self
.
asr_engine
.
config
.
force_yes
):
logger
.
info
(
"start run asr engine"
)
logger
.
debug
(
"start run asr engine"
)
self
.
preprocess
(
self
.
asr_engine
.
config
.
model
,
io
.
BytesIO
(
audio_data
))
st
=
time
.
time
()
...
...
@@ -112,7 +112,7 @@ class PaddleASRConnectionHandler(ASRServerExecutor):
infer_time
=
time
.
time
()
-
st
self
.
output
=
self
.
postprocess
()
# Retrieve result of asr.
else
:
logger
.
info
(
"file check failed!"
)
logger
.
error
(
"file check failed!"
)
self
.
output
=
None
logger
.
info
(
"inference time: {}"
.
format
(
infer_time
))
...
...
paddlespeech/server/engine/cls/paddleinference/cls_engine.py
浏览文件 @
9c4763ec
...
...
@@ -67,22 +67,22 @@ class CLSServerExecutor(CLSExecutor):
self
.
params_path
=
os
.
path
.
abspath
(
params_path
)
self
.
label_file
=
os
.
path
.
abspath
(
label_file
)
logger
.
info
(
self
.
cfg_path
)
logger
.
info
(
self
.
model_path
)
logger
.
info
(
self
.
params_path
)
logger
.
info
(
self
.
label_file
)
logger
.
debug
(
self
.
cfg_path
)
logger
.
debug
(
self
.
model_path
)
logger
.
debug
(
self
.
params_path
)
logger
.
debug
(
self
.
label_file
)
# config
with
open
(
self
.
cfg_path
,
'r'
)
as
f
:
self
.
_conf
=
yaml
.
safe_load
(
f
)
logger
.
info
(
"Read cfg file successfully."
)
logger
.
debug
(
"Read cfg file successfully."
)
# labels
self
.
_label_list
=
[]
with
open
(
self
.
label_file
,
'r'
)
as
f
:
for
line
in
f
:
self
.
_label_list
.
append
(
line
.
strip
())
logger
.
info
(
"Read label file successfully."
)
logger
.
debug
(
"Read label file successfully."
)
# Create predictor
self
.
predictor_conf
=
predictor_conf
...
...
@@ -90,7 +90,7 @@ class CLSServerExecutor(CLSExecutor):
model_file
=
self
.
model_path
,
params_file
=
self
.
params_path
,
predictor_conf
=
self
.
predictor_conf
)
logger
.
info
(
"Create predictor successfully."
)
logger
.
debug
(
"Create predictor successfully."
)
@
paddle
.
no_grad
()
def
infer
(
self
):
...
...
@@ -148,7 +148,8 @@ class CLSEngine(BaseEngine):
logger
.
error
(
e
)
return
False
logger
.
info
(
"Initialize CLS server engine successfully."
)
logger
.
info
(
"Initialize CLS server engine successfully on device: %s."
%
(
self
.
device
))
return
True
...
...
@@ -160,7 +161,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
cls_engine (CLSEngine): The CLS engine
"""
super
().
__init__
()
logger
.
info
(
logger
.
debug
(
"Create PaddleCLSConnectionHandler to process the cls request"
)
self
.
_inputs
=
OrderedDict
()
...
...
@@ -183,7 +184,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
self
.
infer
()
infer_time
=
time
.
time
()
-
st
logger
.
info
(
"inference time: {}"
.
format
(
infer_time
))
logger
.
debug
(
"inference time: {}"
.
format
(
infer_time
))
logger
.
info
(
"cls engine type: inference"
)
def
postprocess
(
self
,
topk
:
int
):
...
...
paddlespeech/server/engine/cls/python/cls_engine.py
浏览文件 @
9c4763ec
...
...
@@ -88,7 +88,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
cls_engine (CLSEngine): The CLS engine
"""
super
().
__init__
()
logger
.
info
(
logger
.
debug
(
"Create PaddleCLSConnectionHandler to process the cls request"
)
self
.
_inputs
=
OrderedDict
()
...
...
@@ -110,7 +110,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
self
.
infer
()
infer_time
=
time
.
time
()
-
st
logger
.
info
(
"inference time: {}"
.
format
(
infer_time
))
logger
.
debug
(
"inference time: {}"
.
format
(
infer_time
))
logger
.
info
(
"cls engine type: python"
)
def
postprocess
(
self
,
topk
:
int
):
...
...
paddlespeech/server/engine/engine_warmup.py
浏览文件 @
9c4763ec
...
...
@@ -45,7 +45,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
logger
.
error
(
"Please check tte engine type."
)
try
:
logger
.
info
(
"Start to warm up tts engine."
)
logger
.
debug
(
"Start to warm up tts engine."
)
for
i
in
range
(
warm_up_time
):
connection_handler
=
PaddleTTSConnectionHandler
(
tts_engine
)
if
flag_online
:
...
...
@@ -53,7 +53,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
text
=
sentence
,
lang
=
tts_engine
.
lang
,
am
=
tts_engine
.
config
.
am
):
logger
.
info
(
logger
.
debug
(
f
"The first response time of the
{
i
}
warm up:
{
connection_handler
.
first_response_time
}
s"
)
break
...
...
@@ -62,7 +62,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
st
=
time
.
time
()
connection_handler
.
infer
(
text
=
sentence
)
et
=
time
.
time
()
logger
.
info
(
logger
.
debug
(
f
"The response time of the
{
i
}
warm up:
{
et
-
st
}
s"
)
except
Exception
as
e
:
logger
.
error
(
"Failed to warm up on tts engine."
)
...
...
paddlespeech/server/engine/text/python/text_engine.py
浏览文件 @
9c4763ec
...
...
@@ -28,7 +28,7 @@ class PaddleTextConnectionHandler:
text_engine (TextEngine): The Text engine
"""
super
().
__init__
()
logger
.
info
(
logger
.
debug
(
"Create PaddleTextConnectionHandler to process the text request"
)
self
.
text_engine
=
text_engine
self
.
task
=
self
.
text_engine
.
executor
.
task
...
...
@@ -130,7 +130,7 @@ class TextEngine(BaseEngine):
"""The Text Engine
"""
super
(
TextEngine
,
self
).
__init__
()
logger
.
info
(
"Create the TextEngine Instance"
)
logger
.
debug
(
"Create the TextEngine Instance"
)
def
init
(
self
,
config
:
dict
):
"""Init the Text Engine
...
...
@@ -141,7 +141,7 @@ class TextEngine(BaseEngine):
Returns:
bool: The engine instance flag
"""
logger
.
info
(
"Init the text engine"
)
logger
.
debug
(
"Init the text engine"
)
try
:
self
.
config
=
config
if
self
.
config
.
device
:
...
...
@@ -150,7 +150,7 @@ class TextEngine(BaseEngine):
self
.
device
=
paddle
.
get_device
()
paddle
.
set_device
(
self
.
device
)
logger
.
info
(
f
"Text Engine set the device:
{
self
.
device
}
"
)
logger
.
debug
(
f
"Text Engine set the device:
{
self
.
device
}
"
)
except
BaseException
as
e
:
logger
.
error
(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
...
...
@@ -168,5 +168,6 @@ class TextEngine(BaseEngine):
ckpt_path
=
config
.
ckpt_path
,
vocab_file
=
config
.
vocab_file
)
logger
.
info
(
"Init the text engine successfully"
)
logger
.
info
(
"Initialize Text server engine successfully on device: %s."
%
(
self
.
device
))
return
True
paddlespeech/server/engine/tts/online/onnx/tts_engine.py
浏览文件 @
9c4763ec
...
...
@@ -62,7 +62,7 @@ class TTSServerExecutor(TTSExecutor):
(
hasattr
(
self
,
'am_encoder_infer_sess'
)
and
hasattr
(
self
,
'am_decoder_sess'
)
and
hasattr
(
self
,
'am_postnet_sess'
)))
and
hasattr
(
self
,
'voc_inference'
):
logger
.
info
(
'Models had been initialized.'
)
logger
.
debug
(
'Models had been initialized.'
)
return
# am
am_tag
=
am
+
'-'
+
lang
...
...
@@ -85,8 +85,7 @@ class TTSServerExecutor(TTSExecutor):
else
:
self
.
am_ckpt
=
os
.
path
.
abspath
(
am_ckpt
[
0
])
self
.
phones_dict
=
os
.
path
.
abspath
(
phones_dict
)
self
.
am_res_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
am_ckpt
))
self
.
am_res_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
am_ckpt
))
# create am sess
self
.
am_sess
=
get_sess
(
self
.
am_ckpt
,
am_sess_conf
)
...
...
@@ -119,8 +118,7 @@ class TTSServerExecutor(TTSExecutor):
self
.
am_postnet
=
os
.
path
.
abspath
(
am_ckpt
[
2
])
self
.
phones_dict
=
os
.
path
.
abspath
(
phones_dict
)
self
.
am_stat
=
os
.
path
.
abspath
(
am_stat
)
self
.
am_res_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
am_ckpt
[
0
]))
self
.
am_res_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
am_ckpt
[
0
]))
# create am sess
self
.
am_encoder_infer_sess
=
get_sess
(
self
.
am_encoder_infer
,
...
...
@@ -130,13 +128,13 @@ class TTSServerExecutor(TTSExecutor):
self
.
am_mu
,
self
.
am_std
=
np
.
load
(
self
.
am_stat
)
logger
.
info
(
f
"self.phones_dict:
{
self
.
phones_dict
}
"
)
logger
.
info
(
f
"am model dir:
{
self
.
am_res_path
}
"
)
logger
.
info
(
"Create am sess successfully."
)
logger
.
debug
(
f
"self.phones_dict:
{
self
.
phones_dict
}
"
)
logger
.
debug
(
f
"am model dir:
{
self
.
am_res_path
}
"
)
logger
.
debug
(
"Create am sess successfully."
)
# voc model info
voc_tag
=
voc
+
'-'
+
lang
if
voc_ckpt
is
None
:
self
.
task_resource
.
set_task_model
(
model_tag
=
voc_tag
,
...
...
@@ -149,16 +147,16 @@ class TTSServerExecutor(TTSExecutor):
else
:
self
.
voc_ckpt
=
os
.
path
.
abspath
(
voc_ckpt
)
self
.
voc_res_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
voc_ckpt
))
logger
.
info
(
self
.
voc_res_path
)
logger
.
debug
(
self
.
voc_res_path
)
# create voc sess
self
.
voc_sess
=
get_sess
(
self
.
voc_ckpt
,
voc_sess_conf
)
logger
.
info
(
"Create voc sess successfully."
)
logger
.
debug
(
"Create voc sess successfully."
)
with
open
(
self
.
phones_dict
,
"r"
)
as
f
:
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
self
.
vocab_size
=
len
(
phn_id
)
logger
.
info
(
f
"vocab_size:
{
self
.
vocab_size
}
"
)
logger
.
debug
(
f
"vocab_size:
{
self
.
vocab_size
}
"
)
# frontend
self
.
tones_dict
=
None
...
...
@@ -169,7 +167,7 @@ class TTSServerExecutor(TTSExecutor):
elif
lang
==
'en'
:
self
.
frontend
=
English
(
phone_vocab_path
=
self
.
phones_dict
)
logger
.
info
(
"frontend done!"
)
logger
.
debug
(
"frontend done!"
)
class
TTSEngine
(
BaseEngine
):
...
...
@@ -267,7 +265,7 @@ class PaddleTTSConnectionHandler:
tts_engine (TTSEngine): The TTS engine
"""
super
().
__init__
()
logger
.
info
(
logger
.
debug
(
"Create PaddleTTSConnectionHandler to process the tts request"
)
self
.
tts_engine
=
tts_engine
...
...
paddlespeech/server/engine/tts/online/python/tts_engine.py
浏览文件 @
9c4763ec
...
...
@@ -102,7 +102,7 @@ class TTSServerExecutor(TTSExecutor):
Init model and other resources from a specific path.
"""
if
hasattr
(
self
,
'am_inference'
)
and
hasattr
(
self
,
'voc_inference'
):
logger
.
info
(
'Models had been initialized.'
)
logger
.
debug
(
'Models had been initialized.'
)
return
# am model info
if
am_ckpt
is
None
or
am_config
is
None
or
am_stat
is
None
or
phones_dict
is
None
:
...
...
@@ -128,17 +128,15 @@ class TTSServerExecutor(TTSExecutor):
# must have phones_dict in acoustic
self
.
phones_dict
=
os
.
path
.
join
(
self
.
am_res_path
,
self
.
task_resource
.
res_dict
[
'phones_dict'
])
print
(
"self.phones_dict:"
,
self
.
phones_dict
)
logger
.
info
(
self
.
am_res_path
)
logger
.
info
(
self
.
am_config
)
logger
.
info
(
self
.
am_ckpt
)
logger
.
debug
(
self
.
am_res_path
)
logger
.
debug
(
self
.
am_config
)
logger
.
debug
(
self
.
am_ckpt
)
else
:
self
.
am_config
=
os
.
path
.
abspath
(
am_config
)
self
.
am_ckpt
=
os
.
path
.
abspath
(
am_ckpt
)
self
.
am_stat
=
os
.
path
.
abspath
(
am_stat
)
self
.
phones_dict
=
os
.
path
.
abspath
(
phones_dict
)
self
.
am_res_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
am_config
))
print
(
"self.phones_dict:"
,
self
.
phones_dict
)
self
.
tones_dict
=
None
self
.
speaker_dict
=
None
...
...
@@ -165,9 +163,9 @@ class TTSServerExecutor(TTSExecutor):
self
.
voc_stat
=
os
.
path
.
join
(
self
.
voc_res_path
,
self
.
task_resource
.
voc_res_dict
[
'speech_stats'
])
logger
.
info
(
self
.
voc_res_path
)
logger
.
info
(
self
.
voc_config
)
logger
.
info
(
self
.
voc_ckpt
)
logger
.
debug
(
self
.
voc_res_path
)
logger
.
debug
(
self
.
voc_config
)
logger
.
debug
(
self
.
voc_ckpt
)
else
:
self
.
voc_config
=
os
.
path
.
abspath
(
voc_config
)
self
.
voc_ckpt
=
os
.
path
.
abspath
(
voc_ckpt
)
...
...
@@ -184,7 +182,6 @@ class TTSServerExecutor(TTSExecutor):
with
open
(
self
.
phones_dict
,
"r"
)
as
f
:
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
self
.
vocab_size
=
len
(
phn_id
)
print
(
"vocab_size:"
,
self
.
vocab_size
)
# frontend
if
lang
==
'zh'
:
...
...
@@ -194,7 +191,6 @@ class TTSServerExecutor(TTSExecutor):
elif
lang
==
'en'
:
self
.
frontend
=
English
(
phone_vocab_path
=
self
.
phones_dict
)
print
(
"frontend done!"
)
# am infer info
self
.
am_name
=
am
[:
am
.
rindex
(
'_'
)]
...
...
@@ -209,7 +205,6 @@ class TTSServerExecutor(TTSExecutor):
self
.
am_name
+
'_inference'
)
self
.
am_inference
=
am_inference_class
(
am_normalizer
,
am
)
self
.
am_inference
.
eval
()
print
(
"acoustic model done!"
)
# voc infer info
self
.
voc_name
=
voc
[:
voc
.
rindex
(
'_'
)]
...
...
@@ -220,7 +215,6 @@ class TTSServerExecutor(TTSExecutor):
'_inference'
)
self
.
voc_inference
=
voc_inference_class
(
voc_normalizer
,
voc
)
self
.
voc_inference
.
eval
()
print
(
"voc done!"
)
class
TTSEngine
(
BaseEngine
):
...
...
@@ -309,7 +303,7 @@ class PaddleTTSConnectionHandler:
tts_engine (TTSEngine): The TTS engine
"""
super
().
__init__
()
logger
.
info
(
logger
.
debug
(
"Create PaddleTTSConnectionHandler to process the tts request"
)
self
.
tts_engine
=
tts_engine
...
...
@@ -369,7 +363,7 @@ class PaddleTTSConnectionHandler:
text
,
merge_sentences
=
merge_sentences
)
phone_ids
=
input_ids
[
"phone_ids"
]
else
:
print
(
"lang should in {'zh', 'en'}!"
)
logger
.
error
(
"lang should in {'zh', 'en'}!"
)
frontend_et
=
time
.
time
()
self
.
frontend_time
=
frontend_et
-
frontend_st
...
...
paddlespeech/server/engine/tts/paddleinference/tts_engine.py
浏览文件 @
9c4763ec
...
...
@@ -65,7 +65,7 @@ class TTSServerExecutor(TTSExecutor):
Init model and other resources from a specific path.
"""
if
hasattr
(
self
,
'am_predictor'
)
and
hasattr
(
self
,
'voc_predictor'
):
logger
.
info
(
'Models had been initialized.'
)
logger
.
debug
(
'Models had been initialized.'
)
return
# am
if
am_model
is
None
or
am_params
is
None
or
phones_dict
is
None
:
...
...
@@ -91,16 +91,16 @@ class TTSServerExecutor(TTSExecutor):
self
.
am_res_path
,
self
.
task_resource
.
res_dict
[
'phones_dict'
])
self
.
am_sample_rate
=
self
.
task_resource
.
res_dict
[
'sample_rate'
]
logger
.
info
(
self
.
am_res_path
)
logger
.
info
(
self
.
am_model
)
logger
.
info
(
self
.
am_params
)
logger
.
debug
(
self
.
am_res_path
)
logger
.
debug
(
self
.
am_model
)
logger
.
debug
(
self
.
am_params
)
else
:
self
.
am_model
=
os
.
path
.
abspath
(
am_model
)
self
.
am_params
=
os
.
path
.
abspath
(
am_params
)
self
.
phones_dict
=
os
.
path
.
abspath
(
phones_dict
)
self
.
am_sample_rate
=
am_sample_rate
self
.
am_res_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
self
.
am_model
))
logger
.
info
(
"self.phones_dict: {}"
.
format
(
self
.
phones_dict
))
logger
.
debug
(
"self.phones_dict: {}"
.
format
(
self
.
phones_dict
))
# for speedyspeech
self
.
tones_dict
=
None
...
...
@@ -139,9 +139,9 @@ class TTSServerExecutor(TTSExecutor):
self
.
voc_res_path
,
self
.
task_resource
.
voc_res_dict
[
'params'
])
self
.
voc_sample_rate
=
self
.
task_resource
.
voc_res_dict
[
'sample_rate'
]
logger
.
info
(
self
.
voc_res_path
)
logger
.
info
(
self
.
voc_model
)
logger
.
info
(
self
.
voc_params
)
logger
.
debug
(
self
.
voc_res_path
)
logger
.
debug
(
self
.
voc_model
)
logger
.
debug
(
self
.
voc_params
)
else
:
self
.
voc_model
=
os
.
path
.
abspath
(
voc_model
)
self
.
voc_params
=
os
.
path
.
abspath
(
voc_params
)
...
...
@@ -156,21 +156,21 @@ class TTSServerExecutor(TTSExecutor):
with
open
(
self
.
phones_dict
,
"r"
)
as
f
:
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
vocab_size
=
len
(
phn_id
)
logger
.
info
(
"vocab_size: {}"
.
format
(
vocab_size
))
logger
.
debug
(
"vocab_size: {}"
.
format
(
vocab_size
))
tone_size
=
None
if
self
.
tones_dict
:
with
open
(
self
.
tones_dict
,
"r"
)
as
f
:
tone_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
tone_size
=
len
(
tone_id
)
logger
.
info
(
"tone_size: {}"
.
format
(
tone_size
))
logger
.
debug
(
"tone_size: {}"
.
format
(
tone_size
))
spk_num
=
None
if
self
.
speaker_dict
:
with
open
(
self
.
speaker_dict
,
'rt'
)
as
f
:
spk_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
spk_num
=
len
(
spk_id
)
logger
.
info
(
"spk_num: {}"
.
format
(
spk_num
))
logger
.
debug
(
"spk_num: {}"
.
format
(
spk_num
))
# frontend
if
lang
==
'zh'
:
...
...
@@ -180,7 +180,7 @@ class TTSServerExecutor(TTSExecutor):
elif
lang
==
'en'
:
self
.
frontend
=
English
(
phone_vocab_path
=
self
.
phones_dict
)
logger
.
info
(
"frontend done!"
)
logger
.
debug
(
"frontend done!"
)
# Create am predictor
self
.
am_predictor_conf
=
am_predictor_conf
...
...
@@ -188,7 +188,7 @@ class TTSServerExecutor(TTSExecutor):
model_file
=
self
.
am_model
,
params_file
=
self
.
am_params
,
predictor_conf
=
self
.
am_predictor_conf
)
logger
.
info
(
"Create AM predictor successfully."
)
logger
.
debug
(
"Create AM predictor successfully."
)
# Create voc predictor
self
.
voc_predictor_conf
=
voc_predictor_conf
...
...
@@ -196,7 +196,7 @@ class TTSServerExecutor(TTSExecutor):
model_file
=
self
.
voc_model
,
params_file
=
self
.
voc_params
,
predictor_conf
=
self
.
voc_predictor_conf
)
logger
.
info
(
"Create Vocoder predictor successfully."
)
logger
.
debug
(
"Create Vocoder predictor successfully."
)
@
paddle
.
no_grad
()
def
infer
(
self
,
...
...
@@ -328,7 +328,8 @@ class TTSEngine(BaseEngine):
logger
.
error
(
e
)
return
False
logger
.
info
(
"Initialize TTS server engine successfully."
)
logger
.
info
(
"Initialize TTS server engine successfully on device: %s."
%
(
self
.
device
))
return
True
...
...
@@ -340,7 +341,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
tts_engine (TTSEngine): The TTS engine
"""
super
().
__init__
()
logger
.
info
(
logger
.
debug
(
"Create PaddleTTSConnectionHandler to process the tts request"
)
self
.
tts_engine
=
tts_engine
...
...
@@ -378,23 +379,23 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
if
target_fs
==
0
or
target_fs
>
original_fs
:
target_fs
=
original_fs
wav_tar_fs
=
wav
logger
.
info
(
logger
.
debug
(
"The sample rate of synthesized audio is the same as model, which is {}Hz"
.
format
(
original_fs
))
else
:
wav_tar_fs
=
librosa
.
resample
(
np
.
squeeze
(
wav
),
original_fs
,
target_fs
)
logger
.
info
(
logger
.
debug
(
"The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully."
.
format
(
original_fs
,
target_fs
))
# transform volume
wav_vol
=
wav_tar_fs
*
volume
logger
.
info
(
"Transform the volume of the audio successfully."
)
logger
.
debug
(
"Transform the volume of the audio successfully."
)
# transform speed
try
:
# windows not support soxbindings
wav_speed
=
change_speed
(
wav_vol
,
speed
,
target_fs
)
logger
.
info
(
"Transform the speed of the audio successfully."
)
logger
.
debug
(
"Transform the speed of the audio successfully."
)
except
ServerBaseException
:
raise
ServerBaseException
(
ErrorCode
.
SERVER_INTERNAL_ERR
,
...
...
@@ -411,7 +412,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
wavfile
.
write
(
buf
,
target_fs
,
wav_speed
)
base64_bytes
=
base64
.
b64encode
(
buf
.
read
())
wav_base64
=
base64_bytes
.
decode
(
'utf-8'
)
logger
.
info
(
"Audio to string successfully."
)
logger
.
debug
(
"Audio to string successfully."
)
# save audio
if
audio_path
is
not
None
:
...
...
@@ -499,15 +500,15 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
logger
.
error
(
e
)
sys
.
exit
(
-
1
)
logger
.
info
(
"AM model: {}"
.
format
(
self
.
config
.
am
))
logger
.
info
(
"Vocoder model: {}"
.
format
(
self
.
config
.
voc
))
logger
.
info
(
"Language: {}"
.
format
(
lang
))
logger
.
debug
(
"AM model: {}"
.
format
(
self
.
config
.
am
))
logger
.
debug
(
"Vocoder model: {}"
.
format
(
self
.
config
.
voc
))
logger
.
debug
(
"Language: {}"
.
format
(
lang
))
logger
.
info
(
"tts engine type: python"
)
logger
.
info
(
"audio duration: {}"
.
format
(
duration
))
logger
.
info
(
"frontend inference time: {}"
.
format
(
self
.
frontend_time
))
logger
.
info
(
"AM inference time: {}"
.
format
(
self
.
am_time
))
logger
.
info
(
"Vocoder inference time: {}"
.
format
(
self
.
voc_time
))
logger
.
debug
(
"frontend inference time: {}"
.
format
(
self
.
frontend_time
))
logger
.
debug
(
"AM inference time: {}"
.
format
(
self
.
am_time
))
logger
.
debug
(
"Vocoder inference time: {}"
.
format
(
self
.
voc_time
))
logger
.
info
(
"total inference time: {}"
.
format
(
infer_time
))
logger
.
info
(
"postprocess (change speed, volume, target sample rate) time: {}"
.
...
...
@@ -515,6 +516,6 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
logger
.
info
(
"total generate audio time: {}"
.
format
(
infer_time
+
postprocess_time
))
logger
.
info
(
"RTF: {}"
.
format
(
rtf
))
logger
.
info
(
"device: {}"
.
format
(
self
.
tts_engine
.
device
))
logger
.
debug
(
"device: {}"
.
format
(
self
.
tts_engine
.
device
))
return
lang
,
target_sample_rate
,
duration
,
wav_base64
paddlespeech/server/engine/tts/python/tts_engine.py
浏览文件 @
9c4763ec
...
...
@@ -105,7 +105,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
tts_engine (TTSEngine): The TTS engine
"""
super
().
__init__
()
logger
.
info
(
logger
.
debug
(
"Create PaddleTTSConnectionHandler to process the tts request"
)
self
.
tts_engine
=
tts_engine
...
...
@@ -143,23 +143,23 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
if
target_fs
==
0
or
target_fs
>
original_fs
:
target_fs
=
original_fs
wav_tar_fs
=
wav
logger
.
info
(
logger
.
debug
(
"The sample rate of synthesized audio is the same as model, which is {}Hz"
.
format
(
original_fs
))
else
:
wav_tar_fs
=
librosa
.
resample
(
np
.
squeeze
(
wav
),
original_fs
,
target_fs
)
logger
.
info
(
logger
.
debug
(
"The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully."
.
format
(
original_fs
,
target_fs
))
# transform volume
wav_vol
=
wav_tar_fs
*
volume
logger
.
info
(
"Transform the volume of the audio successfully."
)
logger
.
debug
(
"Transform the volume of the audio successfully."
)
# transform speed
try
:
# windows not support soxbindings
wav_speed
=
change_speed
(
wav_vol
,
speed
,
target_fs
)
logger
.
info
(
"Transform the speed of the audio successfully."
)
logger
.
debug
(
"Transform the speed of the audio successfully."
)
except
ServerBaseException
:
raise
ServerBaseException
(
ErrorCode
.
SERVER_INTERNAL_ERR
,
...
...
@@ -176,7 +176,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
wavfile
.
write
(
buf
,
target_fs
,
wav_speed
)
base64_bytes
=
base64
.
b64encode
(
buf
.
read
())
wav_base64
=
base64_bytes
.
decode
(
'utf-8'
)
logger
.
info
(
"Audio to string successfully."
)
logger
.
debug
(
"Audio to string successfully."
)
# save audio
if
audio_path
is
not
None
:
...
...
@@ -264,15 +264,15 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
logger
.
error
(
e
)
sys
.
exit
(
-
1
)
logger
.
info
(
"AM model: {}"
.
format
(
self
.
config
.
am
))
logger
.
info
(
"Vocoder model: {}"
.
format
(
self
.
config
.
voc
))
logger
.
info
(
"Language: {}"
.
format
(
lang
))
logger
.
debug
(
"AM model: {}"
.
format
(
self
.
config
.
am
))
logger
.
debug
(
"Vocoder model: {}"
.
format
(
self
.
config
.
voc
))
logger
.
debug
(
"Language: {}"
.
format
(
lang
))
logger
.
info
(
"tts engine type: python"
)
logger
.
info
(
"audio duration: {}"
.
format
(
duration
))
logger
.
info
(
"frontend inference time: {}"
.
format
(
self
.
frontend_time
))
logger
.
info
(
"AM inference time: {}"
.
format
(
self
.
am_time
))
logger
.
info
(
"Vocoder inference time: {}"
.
format
(
self
.
voc_time
))
logger
.
debug
(
"frontend inference time: {}"
.
format
(
self
.
frontend_time
))
logger
.
debug
(
"AM inference time: {}"
.
format
(
self
.
am_time
))
logger
.
debug
(
"Vocoder inference time: {}"
.
format
(
self
.
voc_time
))
logger
.
info
(
"total inference time: {}"
.
format
(
infer_time
))
logger
.
info
(
"postprocess (change speed, volume, target sample rate) time: {}"
.
...
...
@@ -280,6 +280,6 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
logger
.
info
(
"total generate audio time: {}"
.
format
(
infer_time
+
postprocess_time
))
logger
.
info
(
"RTF: {}"
.
format
(
rtf
))
logger
.
info
(
"device: {}"
.
format
(
self
.
tts_engine
.
device
))
logger
.
debug
(
"device: {}"
.
format
(
self
.
tts_engine
.
device
))
return
lang
,
target_sample_rate
,
duration
,
wav_base64
paddlespeech/server/engine/vector/python/vector_engine.py
浏览文件 @
9c4763ec
...
...
@@ -33,7 +33,7 @@ class PaddleVectorConnectionHandler:
vector_engine (VectorEngine): The Vector engine
"""
super
().
__init__
()
logger
.
info
(
logger
.
debug
(
"Create PaddleVectorConnectionHandler to process the vector request"
)
self
.
vector_engine
=
vector_engine
self
.
executor
=
self
.
vector_engine
.
executor
...
...
@@ -54,7 +54,7 @@ class PaddleVectorConnectionHandler:
Returns:
str: the punctuation text
"""
logger
.
info
(
logger
.
debug
(
f
"start to extract the do vector
{
self
.
task
}
from the http request"
)
if
self
.
task
==
"spk"
and
task
==
"spk"
:
embedding
=
self
.
extract_audio_embedding
(
audio_data
)
...
...
@@ -81,17 +81,17 @@ class PaddleVectorConnectionHandler:
Returns:
float: the score between enroll and test audio
"""
logger
.
info
(
"start to extract the enroll audio embedding"
)
logger
.
debug
(
"start to extract the enroll audio embedding"
)
enroll_emb
=
self
.
extract_audio_embedding
(
enroll_audio
)
logger
.
info
(
"start to extract the test audio embedding"
)
logger
.
debug
(
"start to extract the test audio embedding"
)
test_emb
=
self
.
extract_audio_embedding
(
test_audio
)
logger
.
info
(
logger
.
debug
(
"start to get the score between the enroll and test embedding"
)
score
=
self
.
executor
.
get_embeddings_score
(
enroll_emb
,
test_emb
)
logger
.
info
(
f
"get the enroll vs test score:
{
score
}
"
)
logger
.
debug
(
f
"get the enroll vs test score:
{
score
}
"
)
return
score
@
paddle
.
no_grad
()
...
...
@@ -106,11 +106,12 @@ class PaddleVectorConnectionHandler:
# because the soundfile will change the io.BytesIO(audio) to the end
# thus we should convert the base64 string to io.BytesIO when we need the audio data
if
not
self
.
executor
.
_check
(
io
.
BytesIO
(
audio
),
sample_rate
):
logger
.
info
(
"check the audio sample rate occurs error"
)
logger
.
debug
(
"check the audio sample rate occurs error"
)
return
np
.
array
([
0.0
])
waveform
,
sr
=
load_audio
(
io
.
BytesIO
(
audio
))
logger
.
info
(
f
"load the audio sample points, shape is:
{
waveform
.
shape
}
"
)
logger
.
debug
(
f
"load the audio sample points, shape is:
{
waveform
.
shape
}
"
)
# stage 2: get the audio feat
# Note: Now we only support fbank feature
...
...
@@ -121,9 +122,9 @@ class PaddleVectorConnectionHandler:
n_mels
=
self
.
config
.
n_mels
,
window_size
=
self
.
config
.
window_size
,
hop_length
=
self
.
config
.
hop_size
)
logger
.
info
(
f
"extract the audio feats, shape is:
{
feats
.
shape
}
"
)
logger
.
debug
(
f
"extract the audio feats, shape is:
{
feats
.
shape
}
"
)
except
Exception
as
e
:
logger
.
info
(
f
"feats occurs exception
{
e
}
"
)
logger
.
error
(
f
"feats occurs exception
{
e
}
"
)
sys
.
exit
(
-
1
)
feats
=
paddle
.
to_tensor
(
feats
).
unsqueeze
(
0
)
...
...
@@ -159,7 +160,7 @@ class VectorEngine(BaseEngine):
"""The Vector Engine
"""
super
(
VectorEngine
,
self
).
__init__
()
logger
.
info
(
"Create the VectorEngine Instance"
)
logger
.
debug
(
"Create the VectorEngine Instance"
)
def
init
(
self
,
config
:
dict
):
"""Init the Vector Engine
...
...
@@ -170,7 +171,7 @@ class VectorEngine(BaseEngine):
Returns:
bool: The engine instance flag
"""
logger
.
info
(
"Init the vector engine"
)
logger
.
debug
(
"Init the vector engine"
)
try
:
self
.
config
=
config
if
self
.
config
.
device
:
...
...
@@ -179,7 +180,7 @@ class VectorEngine(BaseEngine):
self
.
device
=
paddle
.
get_device
()
paddle
.
set_device
(
self
.
device
)
logger
.
info
(
f
"Vector Engine set the device:
{
self
.
device
}
"
)
logger
.
debug
(
f
"Vector Engine set the device:
{
self
.
device
}
"
)
except
BaseException
as
e
:
logger
.
error
(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
...
...
@@ -196,5 +197,7 @@ class VectorEngine(BaseEngine):
ckpt_path
=
config
.
ckpt_path
,
task
=
config
.
task
)
logger
.
info
(
"Init the Vector engine successfully"
)
logger
.
info
(
"Initialize Vector server engine successfully on device: %s."
%
(
self
.
device
))
return
True
paddlespeech/server/utils/audio_handler.py
浏览文件 @
9c4763ec
...
...
@@ -138,7 +138,7 @@ class ASRWsAudioHandler:
Returns:
str: the final asr result
"""
logging
.
info
(
"send a message to the server"
)
logging
.
debug
(
"send a message to the server"
)
if
self
.
url
is
None
:
logger
.
error
(
"No asr server, please input valid ip and port"
)
...
...
@@ -160,7 +160,7 @@ class ASRWsAudioHandler:
separators
=
(
','
,
': '
))
await
ws
.
send
(
audio_info
)
msg
=
await
ws
.
recv
()
logger
.
info
(
"client receive msg={}"
.
format
(
msg
))
logger
.
debug
(
"client receive msg={}"
.
format
(
msg
))
# 3. send chunk audio data to engine
for
chunk_data
in
self
.
read_wave
(
wavfile_path
):
...
...
@@ -170,7 +170,7 @@ class ASRWsAudioHandler:
if
self
.
punc_server
and
len
(
msg
[
"result"
])
>
0
:
msg
[
"result"
]
=
self
.
punc_server
.
run
(
msg
[
"result"
])
logger
.
info
(
"client receive msg={}"
.
format
(
msg
))
logger
.
debug
(
"client receive msg={}"
.
format
(
msg
))
# 4. we must send finished signal to the server
audio_info
=
json
.
dumps
(
...
...
@@ -310,7 +310,7 @@ class TTSWsHandler:
start_request
=
json
.
dumps
({
"task"
:
"tts"
,
"signal"
:
"start"
})
await
ws
.
send
(
start_request
)
msg
=
await
ws
.
recv
()
logger
.
info
(
f
"client receive msg=
{
msg
}
"
)
logger
.
debug
(
f
"client receive msg=
{
msg
}
"
)
msg
=
json
.
loads
(
msg
)
session
=
msg
[
"session"
]
...
...
@@ -319,7 +319,7 @@ class TTSWsHandler:
request
=
json
.
dumps
({
"text"
:
text_base64
})
st
=
time
.
time
()
await
ws
.
send
(
request
)
logging
.
info
(
"send a message to the server"
)
logging
.
debug
(
"send a message to the server"
)
# 4. Process the received response
message
=
await
ws
.
recv
()
...
...
@@ -543,7 +543,6 @@ class VectorHttpHandler:
"sample_rate"
:
sample_rate
,
}
logger
.
info
(
self
.
url
)
res
=
requests
.
post
(
url
=
self
.
url
,
data
=
json
.
dumps
(
data
))
return
res
.
json
()
...
...
paddlespeech/server/utils/audio_process.py
浏览文件 @
9c4763ec
...
...
@@ -169,7 +169,7 @@ def save_audio(bytes_data, audio_path, sample_rate: int=24000) -> bool:
sample_rate
=
sample_rate
)
os
.
remove
(
"./tmp.pcm"
)
else
:
print
(
"Only supports saved audio format is pcm or wav"
)
logger
.
error
(
"Only supports saved audio format is pcm or wav"
)
return
False
return
True
paddlespeech/server/utils/onnx_infer.py
浏览文件 @
9c4763ec
...
...
@@ -20,7 +20,7 @@ from paddlespeech.cli.log import logger
def
get_sess
(
model_path
:
Optional
[
os
.
PathLike
]
=
None
,
sess_conf
:
dict
=
None
):
logger
.
info
(
f
"ort sessconf:
{
sess_conf
}
"
)
logger
.
debug
(
f
"ort sessconf:
{
sess_conf
}
"
)
sess_options
=
ort
.
SessionOptions
()
sess_options
.
graph_optimization_level
=
ort
.
GraphOptimizationLevel
.
ORT_ENABLE_ALL
if
sess_conf
.
get
(
'graph_optimization_level'
,
99
)
==
0
:
...
...
@@ -34,7 +34,7 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
# fastspeech2/mb_melgan can't use trt now!
if
sess_conf
.
get
(
"use_trt"
,
0
):
providers
=
[
'TensorrtExecutionProvider'
]
logger
.
info
(
f
"ort providers:
{
providers
}
"
)
logger
.
debug
(
f
"ort providers:
{
providers
}
"
)
if
'cpu_threads'
in
sess_conf
:
sess_options
.
intra_op_num_threads
=
sess_conf
.
get
(
"cpu_threads"
,
0
)
...
...
paddlespeech/server/utils/util.py
浏览文件 @
9c4763ec
...
...
@@ -13,6 +13,8 @@
import
base64
import
math
from
paddlespeech.cli.log
import
logger
def
wav2base64
(
wav_file
:
str
):
"""
...
...
@@ -61,7 +63,7 @@ def get_chunks(data, block_size, pad_size, step):
elif
step
==
"voc"
:
data_len
=
data
.
shape
[
0
]
else
:
print
(
"Please set correct type to get chunks, am or voc"
)
logger
.
error
(
"Please set correct type to get chunks, am or voc"
)
chunks
=
[]
n
=
math
.
ceil
(
data_len
/
block_size
)
...
...
@@ -73,7 +75,7 @@ def get_chunks(data, block_size, pad_size, step):
elif
step
==
"voc"
:
chunks
.
append
(
data
[
start
:
end
,
:])
else
:
print
(
"Please set correct type to get chunks, am or voc"
)
logger
.
error
(
"Please set correct type to get chunks, am or voc"
)
return
chunks
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录