Commit 9c4763ec (unverified)
Authored on Jul 05, 2022 by 小湉湉, committed via GitHub on Jul 05, 2022

Merge pull request #2113 from yt605155624/rm_server_log
[server]log redundancy in server

Parents: e4a8e153, 4b1f82d3

Showing 21 changed files with 226 additions and 223 deletions (+226 / -223)
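The pattern repeated across the 21 files is the same: chatty per-request and per-chunk messages are demoted from logger.info to logger.debug (or to logger.error where they actually report failures), while the one-line "initialized successfully" messages stay at info and now also report the device. The sketch below is a generic illustration using Python's standard logging module, not PaddleSpeech's own logger object, and its function and logger names are invented for the example. It shows why the demotion removes redundancy: at the usual INFO threshold the demoted messages are suppressed, yet they remain available by lowering the level to DEBUG.

import logging

logging.basicConfig(level=logging.INFO)      # typical server default: INFO and above
logger = logging.getLogger("speech_server")  # hypothetical logger name

def extract_feat(samples):
    # Formerly info: emitted once per audio chunk, flooding the server log.
    logger.debug("This package receive %d pcm samples.", len(samples))
    return samples

def init_engine(device="cpu"):
    logger.debug("Init the engine")  # hidden at INFO level
    # Kept at info: a single success line that now carries the device.
    logger.info("Initialize engine successfully on device: %s.", device)
    return True

init_engine()
extract_feat([0] * 160)
# Re-enable the verbose messages only when debugging:
# logging.getLogger("speech_server").setLevel(logging.DEBUG)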
paddlespeech/cli/tts/infer.py  (+1, -1)
paddlespeech/server/bin/paddlespeech_client.py  (+0, -2)
paddlespeech/server/engine/acs/python/acs_engine.py  (+11, -9)
paddlespeech/server/engine/asr/online/onnx/asr_engine.py  (+22, -22)
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py  (+25, -23)
paddlespeech/server/engine/asr/online/python/asr_engine.py  (+44, -42)
paddlespeech/server/engine/asr/paddleinference/asr_engine.py  (+7, -7)
paddlespeech/server/engine/asr/python/asr_engine.py  (+2, -2)
paddlespeech/server/engine/cls/paddleinference/cls_engine.py  (+11, -10)
paddlespeech/server/engine/cls/python/cls_engine.py  (+2, -2)
paddlespeech/server/engine/engine_warmup.py  (+3, -3)
paddlespeech/server/engine/text/python/text_engine.py  (+6, -5)
paddlespeech/server/engine/tts/online/onnx/tts_engine.py  (+12, -14)
paddlespeech/server/engine/tts/online/python/tts_engine.py  (+9, -15)
paddlespeech/server/engine/tts/paddleinference/tts_engine.py  (+29, -28)
paddlespeech/server/engine/tts/python/tts_engine.py  (+13, -13)
paddlespeech/server/engine/vector/python/vector_engine.py  (+17, -14)
paddlespeech/server/utils/audio_handler.py  (+5, -6)
paddlespeech/server/utils/audio_process.py  (+1, -1)
paddlespeech/server/utils/onnx_infer.py  (+2, -2)
paddlespeech/server/utils/util.py  (+4, -2)
paddlespeech/cli/tts/infer.py

@@ -382,7 +382,7 @@ class TTSExecutor(BaseExecutor):
                     text, merge_sentences=merge_sentences)
                 phone_ids = input_ids["phone_ids"]
             else:
-                print("lang should in {'zh', 'en'}!")
+                logger.error("lang should in {'zh', 'en'}!")
         self.frontend_time = time.time() - frontend_st
         self.am_time = 0
paddlespeech/server/bin/paddlespeech_client.py

@@ -123,7 +123,6 @@ class TTSClientExecutor(BaseExecutor):
             time_end = time.time()
             time_consume = time_end - time_start
             response_dict = res.json()
-            logger.info(response_dict["message"])
             logger.info("Save synthesized audio successfully on %s." % (output))
             logger.info("Audio duration: %f s." %
                         (response_dict['result']['duration']))

@@ -702,7 +701,6 @@ class VectorClientExecutor(BaseExecutor):
                 test_audio=args.test,
                 task=task)
             time_end = time.time()
-            logger.info(f"The vector: {res}")
             logger.info("Response time %f s." % (time_end - time_start))
             return True
         except Exception as e:
paddlespeech/server/engine/acs/python/acs_engine.py

@@ -30,7 +30,7 @@ class ACSEngine(BaseEngine):
         """The ACSEngine Engine
         """
         super(ACSEngine, self).__init__()
-        logger.info("Create the ACSEngine Instance")
+        logger.debug("Create the ACSEngine Instance")
         self.word_list = []

     def init(self, config: dict):

@@ -42,7 +42,7 @@ class ACSEngine(BaseEngine):
         Returns:
             bool: The engine instance flag
         """
-        logger.info("Init the acs engine")
+        logger.debug("Init the acs engine")
         try:
             self.config = config
             self.device = self.config.get("device", paddle.get_device())

@@ -50,7 +50,7 @@ class ACSEngine(BaseEngine):
             # websocket default ping timeout is 20 seconds
             self.ping_timeout = self.config.get("ping_timeout", 20)
             paddle.set_device(self.device)
-            logger.info(f"ACS Engine set the device: {self.device}")
+            logger.debug(f"ACS Engine set the device: {self.device}")
         except BaseException as e:
             logger.error(

@@ -66,7 +66,9 @@ class ACSEngine(BaseEngine):
         self.url = "ws://" + self.config.asr_server_ip + ":" + str(
             self.config.asr_server_port) + "/paddlespeech/asr/streaming"
-        logger.info("Init the acs engine successfully")
+        logger.info("Initialize acs server engine successfully on device: %s." %
+                    (self.device))
         return True

     def read_search_words(self):

@@ -95,12 +97,12 @@ class ACSEngine(BaseEngine):
         Returns:
             _type_: _description_
         """
-        logger.info("send a message to the server")
+        logger.debug("send a message to the server")
         if self.url is None:
             logger.error("No asr server, please input valid ip and port")
             return ""
         ws = websocket.WebSocket()
-        logger.info(f"set the ping timeout: {self.ping_timeout} seconds")
+        logger.debug(f"set the ping timeout: {self.ping_timeout} seconds")
         ws.connect(self.url, ping_timeout=self.ping_timeout)
         audio_info = json.dumps(
             {

@@ -123,7 +125,7 @@ class ACSEngine(BaseEngine):
         logger.info(f"audio result: {msg}")

         # 3. send chunk audio data to engine
-        logger.info("send the end signal")
+        logger.debug("send the end signal")
         audio_info = json.dumps(
             {
                 "name": "test.wav",

@@ -197,7 +199,7 @@ class ACSEngine(BaseEngine):
                 start = max(time_stamp[m.start(0)]['bg'] - offset, 0)
                 end = min(time_stamp[m.end(0) - 1]['ed'] + offset, max_ed)
-                logger.info(f'start: {start}, end: {end}')
+                logger.debug(f'start: {start}, end: {end}')
                 acs_result.append({'w': w, 'bg': start, 'ed': end})

         return acs_result, asr_result

@@ -212,7 +214,7 @@ class ACSEngine(BaseEngine):
         Returns:
             acs_result, asr_result: the acs result and the asr result
         """
-        logger.info("start to process the audio content search")
+        logger.debug("start to process the audio content search")
         msg = self.get_asr_content(io.BytesIO(audio_data))

         acs_result, asr_result = self.get_macthed_word(msg)
paddlespeech/server/engine/asr/online/onnx/asr_engine.py

@@ -44,7 +44,7 @@ class PaddleASRConnectionHanddler:
            asr_engine (ASREngine): the global asr engine
        """
        super().__init__()
-        logger.info(
+        logger.debug(
            "create an paddle asr connection handler to process the websocket connection"
        )
        self.config = asr_engine.config  # server config

@@ -152,12 +152,12 @@ class PaddleASRConnectionHanddler:
        self.output_reset()

    def extract_feat(self, samples: ByteString):
-        logger.info("Online ASR extract the feat")
+        logger.debug("Online ASR extract the feat")
        samples = np.frombuffer(samples, dtype=np.int16)
        assert samples.ndim == 1

        self.num_samples += samples.shape[0]
-        logger.info(
+        logger.debug(
            f"This package receive {samples.shape[0]} pcm data. Global samples: {self.num_samples}"
        )

@@ -168,7 +168,7 @@ class PaddleASRConnectionHanddler:
        else:
            assert self.remained_wav.ndim == 1  # (T,)
            self.remained_wav = np.concatenate([self.remained_wav, samples])
-        logger.info(
+        logger.debug(
            f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
        )

@@ -202,14 +202,14 @@ class PaddleASRConnectionHanddler:
            # update remained wav
            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]

-            logger.info(
+            logger.debug(
                f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
            )
-            logger.info(
+            logger.debug(
                f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
            )
-        logger.info(f"global samples: {self.num_samples}")
-        logger.info(f"global frames: {self.num_frames}")
+        logger.debug(f"global samples: {self.num_samples}")
+        logger.debug(f"global frames: {self.num_frames}")

    def decode(self, is_finished=False):
        """advance decoding

@@ -237,7 +237,7 @@ class PaddleASRConnectionHanddler:
            return

        num_frames = self.cached_feat.shape[1]
-        logger.info(
+        logger.debug(
            f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
        )

@@ -355,7 +355,7 @@ class ASRServerExecutor(ASRExecutor):
            lm_url = self.task_resource.res_dict['lm_url']
            lm_md5 = self.task_resource.res_dict['lm_md5']
-            logger.info(f"Start to load language model {lm_url}")
+            logger.debug(f"Start to load language model {lm_url}")
            self.download_lm(
                lm_url,
                os.path.dirname(self.config.decode.lang_model_path), lm_md5)

@@ -367,7 +367,7 @@ class ASRServerExecutor(ASRExecutor):
        if "deepspeech2" in self.model_type:
            # AM predictor
-            logger.info("ASR engine start to init the am predictor")
+            logger.debug("ASR engine start to init the am predictor")
            self.am_predictor = onnx_infer.get_sess(
                model_path=self.am_model, sess_conf=self.am_predictor_conf)
        else:

@@ -400,7 +400,7 @@ class ASRServerExecutor(ASRExecutor):
        self.num_decoding_left_chunks = num_decoding_left_chunks
        # conf for paddleinference predictor or onnx
        self.am_predictor_conf = am_predictor_conf
-        logger.info(f"model_type: {self.model_type}")
+        logger.debug(f"model_type: {self.model_type}")

        sample_rate_str = '16k' if sample_rate == 16000 else '8k'
        tag = model_type + '-' + lang + '-' + sample_rate_str

@@ -422,12 +422,11 @@ class ASRServerExecutor(ASRExecutor):
        #     self.res_path, self.task_resource.res_dict[
        #         'params']) if am_params is None else os.path.abspath(am_params)
-        logger.info("Load the pretrained model:")
-        logger.info(f" tag = {tag}")
-        logger.info(f" res_path: {self.res_path}")
-        logger.info(f" cfg path: {self.cfg_path}")
-        logger.info(f" am_model path: {self.am_model}")
-        # logger.info(f" am_params path: {self.am_params}")
+        logger.debug("Load the pretrained model:")
+        logger.debug(f" tag = {tag}")
+        logger.debug(f" res_path: {self.res_path}")
+        logger.debug(f" cfg path: {self.cfg_path}")
+        logger.debug(f" am_model path: {self.am_model}")
        #Init body.
        self.config = CfgNode(new_allowed=True)

@@ -436,7 +435,7 @@ class ASRServerExecutor(ASRExecutor):
            if self.config.spm_model_prefix:
                self.config.spm_model_prefix = os.path.join(
                    self.res_path, self.config.spm_model_prefix)
-                logger.info(f"spm model path: {self.config.spm_model_prefix}")
+                logger.debug(f"spm model path: {self.config.spm_model_prefix}")
            self.vocab = self.config.vocab_filepath

@@ -450,7 +449,7 @@ class ASRServerExecutor(ASRExecutor):
        # AM predictor
        self.init_model()

-        logger.info(f"create the {model_type} model success")
+        logger.debug(f"create the {model_type} model success")
        return True

@@ -501,7 +500,7 @@ class ASREngine(BaseEngine):
                "If all GPU or XPU is used, you can set the server to 'cpu'")
            sys.exit(-1)

-        logger.info(f"paddlespeech_server set the device: {self.device}")
+        logger.debug(f"paddlespeech_server set the device: {self.device}")

        if not self.init_model():
            logger.error(

@@ -509,7 +508,8 @@ class ASREngine(BaseEngine):
            )
            return False

-        logger.info("Initialize ASR server engine successfully.")
+        logger.info("Initialize ASR server engine successfully on device: %s." %
+                    (self.device))
        return True

    def new_handler(self):
paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py

@@ -44,7 +44,7 @@ class PaddleASRConnectionHanddler:
            asr_engine (ASREngine): the global asr engine
        """
        super().__init__()
-        logger.info(
+        logger.debug(
            "create an paddle asr connection handler to process the websocket connection"
        )
        self.config = asr_engine.config  # server config

@@ -157,7 +157,7 @@ class PaddleASRConnectionHanddler:
        assert samples.ndim == 1

        self.num_samples += samples.shape[0]
-        logger.info(
+        logger.debug(
            f"This package receive {samples.shape[0]} pcm data. Global samples: {self.num_samples}"
        )

@@ -168,7 +168,7 @@ class PaddleASRConnectionHanddler:
        else:
            assert self.remained_wav.ndim == 1  # (T,)
            self.remained_wav = np.concatenate([self.remained_wav, samples])
-        logger.info(
+        logger.debug(
            f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
        )

@@ -202,14 +202,14 @@ class PaddleASRConnectionHanddler:
            # update remained wav
            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]

-            logger.info(
+            logger.debug(
                f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
            )
-            logger.info(
+            logger.debug(
                f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
            )
-        logger.info(f"global samples: {self.num_samples}")
-        logger.info(f"global frames: {self.num_frames}")
+        logger.debug(f"global samples: {self.num_samples}")
+        logger.debug(f"global frames: {self.num_frames}")

    def decode(self, is_finished=False):
        """advance decoding

@@ -237,13 +237,13 @@ class PaddleASRConnectionHanddler:
            return

        num_frames = self.cached_feat.shape[1]
-        logger.info(
+        logger.debug(
            f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
        )
        # the cached feat must be larger decoding_window
        if num_frames < decoding_window and not is_finished:
-            logger.info(
+            logger.debug(
                f"frame feat num is less than {decoding_window}, please input more pcm data"
            )
            return None, None

@@ -294,7 +294,7 @@ class PaddleASRConnectionHanddler:
        Returns:
            logprob: poster probability.
        """
-        logger.info("start to decoce one chunk for deepspeech2")
+        logger.debug("start to decoce one chunk for deepspeech2")
        input_names = self.am_predictor.get_input_names()
        audio_handle = self.am_predictor.get_input_handle(input_names[0])
        audio_len_handle = self.am_predictor.get_input_handle(input_names[1])

@@ -369,7 +369,7 @@ class ASRServerExecutor(ASRExecutor):
            lm_url = self.task_resource.res_dict['lm_url']
            lm_md5 = self.task_resource.res_dict['lm_md5']
-            logger.info(f"Start to load language model {lm_url}")
+            logger.debug(f"Start to load language model {lm_url}")
            self.download_lm(
                lm_url,
                os.path.dirname(self.config.decode.lang_model_path), lm_md5)

@@ -381,7 +381,7 @@ class ASRServerExecutor(ASRExecutor):
        if "deepspeech2" in self.model_type:
            # AM predictor
-            logger.info("ASR engine start to init the am predictor")
+            logger.debug("ASR engine start to init the am predictor")
            self.am_predictor = init_predictor(
                model_file=self.am_model,
                params_file=self.am_params,

@@ -415,7 +415,7 @@ class ASRServerExecutor(ASRExecutor):
        self.num_decoding_left_chunks = num_decoding_left_chunks
        # conf for paddleinference predictor or onnx
        self.am_predictor_conf = am_predictor_conf
-        logger.info(f"model_type: {self.model_type}")
+        logger.debug(f"model_type: {self.model_type}")

        sample_rate_str = '16k' if sample_rate == 16000 else '8k'
        tag = model_type + '-' + lang + '-' + sample_rate_str

@@ -437,12 +437,12 @@ class ASRServerExecutor(ASRExecutor):
            self.res_path = os.path.dirname(
                os.path.dirname(os.path.abspath(self.cfg_path)))
-        logger.info("Load the pretrained model:")
-        logger.info(f" tag = {tag}")
-        logger.info(f" res_path: {self.res_path}")
-        logger.info(f" cfg path: {self.cfg_path}")
-        logger.info(f" am_model path: {self.am_model}")
-        logger.info(f" am_params path: {self.am_params}")
+        logger.debug("Load the pretrained model:")
+        logger.debug(f" tag = {tag}")
+        logger.debug(f" res_path: {self.res_path}")
+        logger.debug(f" cfg path: {self.cfg_path}")
+        logger.debug(f" am_model path: {self.am_model}")
+        logger.debug(f" am_params path: {self.am_params}")
        #Init body.
        self.config = CfgNode(new_allowed=True)

@@ -451,7 +451,7 @@ class ASRServerExecutor(ASRExecutor):
            if self.config.spm_model_prefix:
                self.config.spm_model_prefix = os.path.join(
                    self.res_path, self.config.spm_model_prefix)
-                logger.info(f"spm model path: {self.config.spm_model_prefix}")
+                logger.debug(f"spm model path: {self.config.spm_model_prefix}")
            self.vocab = self.config.vocab_filepath

@@ -465,7 +465,7 @@ class ASRServerExecutor(ASRExecutor):
        # AM predictor
        self.init_model()

-        logger.info(f"create the {model_type} model success")
+        logger.debug(f"create the {model_type} model success")
        return True

@@ -516,7 +516,7 @@ class ASREngine(BaseEngine):
                "If all GPU or XPU is used, you can set the server to 'cpu'")
            sys.exit(-1)

-        logger.info(f"paddlespeech_server set the device: {self.device}")
+        logger.debug(f"paddlespeech_server set the device: {self.device}")

        if not self.init_model():
            logger.error(

@@ -524,7 +524,9 @@ class ASREngine(BaseEngine):
            )
            return False

-        logger.info("Initialize ASR server engine successfully.")
+        logger.info("Initialize ASR server engine successfully on device: %s." %
+                    (self.device))
        return True

    def new_handler(self):
paddlespeech/server/engine/asr/online/python/asr_engine.py

@@ -49,7 +49,7 @@ class PaddleASRConnectionHanddler:
            asr_engine (ASREngine): the global asr engine
        """
        super().__init__()
-        logger.info(
+        logger.debug(
            "create an paddle asr connection handler to process the websocket connection"
        )
        self.config = asr_engine.config  # server config

@@ -107,7 +107,7 @@ class PaddleASRConnectionHanddler:
        # acoustic model
        self.model = self.asr_engine.executor.model
        self.continuous_decoding = self.config.continuous_decoding
-        logger.info(f"continue decoding: {self.continuous_decoding}")
+        logger.debug(f"continue decoding: {self.continuous_decoding}")

        # ctc decoding config
        self.ctc_decode_config = self.asr_engine.executor.config.decode

@@ -207,7 +207,7 @@ class PaddleASRConnectionHanddler:
        assert samples.ndim == 1

        self.num_samples += samples.shape[0]
-        logger.info(
+        logger.debug(
            f"This package receive {samples.shape[0]} pcm data. Global samples: {self.num_samples}"
        )

@@ -218,7 +218,7 @@ class PaddleASRConnectionHanddler:
        else:
            assert self.remained_wav.ndim == 1  # (T,)
            self.remained_wav = np.concatenate([self.remained_wav, samples])
-        logger.info(
+        logger.debug(
            f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
        )

@@ -252,14 +252,14 @@ class PaddleASRConnectionHanddler:
            # update remained wav
            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]

-            logger.info(
+            logger.debug(
                f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
            )
-            logger.info(
+            logger.debug(
                f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
            )
-        logger.info(f"global samples: {self.num_samples}")
-        logger.info(f"global frames: {self.num_frames}")
+        logger.debug(f"global samples: {self.num_samples}")
+        logger.debug(f"global frames: {self.num_frames}")

    def decode(self, is_finished=False):
        """advance decoding

@@ -283,24 +283,24 @@ class PaddleASRConnectionHanddler:
        stride = subsampling * decoding_chunk_size

        if self.cached_feat is None:
-            logger.info("no audio feat, please input more pcm data")
+            logger.debug("no audio feat, please input more pcm data")
            return

        num_frames = self.cached_feat.shape[1]
-        logger.info(
+        logger.debug(
            f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
        )
        # the cached feat must be larger decoding_window
        if num_frames < decoding_window and not is_finished:
-            logger.info(
+            logger.debug(
                f"frame feat num is less than {decoding_window}, please input more pcm data"
            )
            return None, None

        # if is_finished=True, we need at least context frames
        if num_frames < context:
-            logger.info(
+            logger.debug(
                "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
            )
            return None, None

@@ -354,7 +354,7 @@ class PaddleASRConnectionHanddler:
        Returns:
            logprob: poster probability.
        """
-        logger.info("start to decoce one chunk for deepspeech2")
+        logger.debug("start to decoce one chunk for deepspeech2")
        input_names = self.am_predictor.get_input_names()
        audio_handle = self.am_predictor.get_input_handle(input_names[0])
        audio_len_handle = self.am_predictor.get_input_handle(input_names[1])

@@ -391,7 +391,7 @@ class PaddleASRConnectionHanddler:
        self.decoder.next(output_chunk_probs, output_chunk_lens)

        trans_best, trans_beam = self.decoder.decode()
-        logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
+        logger.debug(f"decode one best result for deepspeech2: {trans_best[0]}")
        return trans_best[0]

    @paddle.no_grad()

@@ -402,7 +402,7 @@ class PaddleASRConnectionHanddler:
        # reset endpiont state
        self.endpoint_state = False
-        logger.info(
+        logger.debug(
            "Conformer/Transformer: start to decode with advanced_decoding method"
        )
        cfg = self.ctc_decode_config

@@ -427,25 +427,25 @@ class PaddleASRConnectionHanddler:
        stride = subsampling * decoding_chunk_size

        if self.cached_feat is None:
-            logger.info("no audio feat, please input more pcm data")
+            logger.debug("no audio feat, please input more pcm data")
            return

        # (B=1,T,D)
        num_frames = self.cached_feat.shape[1]
-        logger.info(
+        logger.debug(
            f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
        )
        # the cached feat must be larger decoding_window
        if num_frames < decoding_window and not is_finished:
-            logger.info(
+            logger.debug(
                f"frame feat num is less than {decoding_window}, please input more pcm data"
            )
            return None, None

        # if is_finished=True, we need at least context frames
        if num_frames < context:
-            logger.info(
+            logger.debug(
                "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
            )
            return None, None

@@ -489,7 +489,7 @@ class PaddleASRConnectionHanddler:
            self.encoder_out = ys
        else:
            self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
-        logger.info(
+        logger.debug(
            f"This connection handler encoder out shape: {self.encoder_out.shape}"
        )

@@ -513,7 +513,8 @@ class PaddleASRConnectionHanddler:
            if self.endpointer.endpoint_detected(ctc_probs.numpy(),
                                                 decoding_something):
                self.endpoint_state = True
-                logger.info(f"Endpoint is detected at {self.num_frames} frame.")
+                logger.debug(
+                    f"Endpoint is detected at {self.num_frames} frame.")

        # advance cache of feat
        assert self.cached_feat.shape[0] == 1  #(B=1,T,D)

@@ -526,7 +527,7 @@ class PaddleASRConnectionHanddler:
    def update_result(self):
        """Conformer/Transformer hyps to result.
        """
-        logger.info("update the final result")
+        logger.debug("update the final result")
        hyps = self.hyps

        # output results and tokenids

@@ -560,16 +561,16 @@ class PaddleASRConnectionHanddler:
        only for conformer and transformer model.
        """
        if "deepspeech2" in self.model_type:
-            logger.info("deepspeech2 not support rescoring decoding.")
+            logger.debug("deepspeech2 not support rescoring decoding.")
            return

        if "attention_rescoring" != self.ctc_decode_config.decoding_method:
-            logger.info(
+            logger.debug(
                f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring"
            )
            return

-        logger.info("rescoring the final result")
+        logger.debug("rescoring the final result")

        # last decoding for last audio
        self.searcher.finalize_search()

@@ -685,7 +686,6 @@ class PaddleASRConnectionHanddler:
                "bg": global_offset_in_sec + start,
                "ed": global_offset_in_sec + end
            })
-            # logger.info(f"{word_time_stamp[-1]}")

        self.word_time_stamp = word_time_stamp
        logger.info(f"word time stamp: {self.word_time_stamp}")

@@ -707,13 +707,13 @@ class ASRServerExecutor(ASRExecutor):
            lm_url = self.task_resource.res_dict['lm_url']
            lm_md5 = self.task_resource.res_dict['lm_md5']
-            logger.info(f"Start to load language model {lm_url}")
+            logger.debug(f"Start to load language model {lm_url}")
            self.download_lm(
                lm_url,
                os.path.dirname(self.config.decode.lang_model_path), lm_md5)
        elif "conformer" in self.model_type or "transformer" in self.model_type:
            with UpdateConfig(self.config):
-                logger.info("start to create the stream conformer asr engine")
+                logger.debug("start to create the stream conformer asr engine")
                # update the decoding method
                if self.decode_method:
                    self.config.decode.decoding_method = self.decode_method

@@ -726,7 +726,7 @@ class ASRServerExecutor(ASRExecutor):
                if self.config.decode.decoding_method not in [
                        "ctc_prefix_beam_search", "attention_rescoring"
                ]:
-                    logger.info(
+                    logger.debug(
                        "we set the decoding_method to attention_rescoring")
                    self.config.decode.decoding_method = "attention_rescoring"

@@ -739,7 +739,7 @@ class ASRServerExecutor(ASRExecutor):
    def init_model(self) -> None:
        if "deepspeech2" in self.model_type:
            # AM predictor
-            logger.info("ASR engine start to init the am predictor")
+            logger.debug("ASR engine start to init the am predictor")
            self.am_predictor = init_predictor(
                model_file=self.am_model,
                params_file=self.am_params,

@@ -748,7 +748,7 @@ class ASRServerExecutor(ASRExecutor):
            # load model
            # model_type: {model_name}_{dataset}
            model_name = self.model_type[:self.model_type.rindex('_')]
-            logger.info(f"model name: {model_name}")
+            logger.debug(f"model name: {model_name}")
            model_class = self.task_resource.get_model_class(model_name)
            model = model_class.from_config(self.config)
            self.model = model

@@ -782,7 +782,7 @@ class ASRServerExecutor(ASRExecutor):
        self.num_decoding_left_chunks = num_decoding_left_chunks
        # conf for paddleinference predictor or onnx
        self.am_predictor_conf = am_predictor_conf
-        logger.info(f"model_type: {self.model_type}")
+        logger.debug(f"model_type: {self.model_type}")

        sample_rate_str = '16k' if sample_rate == 16000 else '8k'
        tag = model_type + '-' + lang + '-' + sample_rate_str

@@ -804,12 +804,12 @@ class ASRServerExecutor(ASRExecutor):
            self.res_path = os.path.dirname(
                os.path.dirname(os.path.abspath(self.cfg_path)))
-        logger.info("Load the pretrained model:")
-        logger.info(f" tag = {tag}")
-        logger.info(f" res_path: {self.res_path}")
-        logger.info(f" cfg path: {self.cfg_path}")
-        logger.info(f" am_model path: {self.am_model}")
-        logger.info(f" am_params path: {self.am_params}")
+        logger.debug("Load the pretrained model:")
+        logger.debug(f" tag = {tag}")
+        logger.debug(f" res_path: {self.res_path}")
+        logger.debug(f" cfg path: {self.cfg_path}")
+        logger.debug(f" am_model path: {self.am_model}")
+        logger.debug(f" am_params path: {self.am_params}")
        #Init body.
        self.config = CfgNode(new_allowed=True)

@@ -818,7 +818,7 @@ class ASRServerExecutor(ASRExecutor):
            if self.config.spm_model_prefix:
                self.config.spm_model_prefix = os.path.join(
                    self.res_path, self.config.spm_model_prefix)
-                logger.info(f"spm model path: {self.config.spm_model_prefix}")
+                logger.debug(f"spm model path: {self.config.spm_model_prefix}")
            self.vocab = self.config.vocab_filepath

@@ -832,7 +832,7 @@ class ASRServerExecutor(ASRExecutor):
        # AM predictor
        self.init_model()

-        logger.info(f"create the {model_type} model success")
+        logger.debug(f"create the {model_type} model success")
        return True

@@ -883,7 +883,7 @@ class ASREngine(BaseEngine):
                "If all GPU or XPU is used, you can set the server to 'cpu'")
            sys.exit(-1)

-        logger.info(f"paddlespeech_server set the device: {self.device}")
+        logger.debug(f"paddlespeech_server set the device: {self.device}")

        if not self.init_model():
            logger.error(

@@ -891,7 +891,9 @@ class ASREngine(BaseEngine):
            )
            return False

-        logger.info("Initialize ASR server engine successfully.")
+        logger.info("Initialize ASR server engine successfully on device: %s." %
+                    (self.device))
        return True

    def new_handler(self):
paddlespeech/server/engine/asr/paddleinference/asr_engine.py

@@ -65,10 +65,10 @@ class ASRServerExecutor(ASRExecutor):
                                         self.task_resource.res_dict['model'])
            self.am_params = os.path.join(self.res_path,
                                          self.task_resource.res_dict['params'])
-            logger.info(self.res_path)
-            logger.info(self.cfg_path)
-            logger.info(self.am_model)
-            logger.info(self.am_params)
+            logger.debug(self.res_path)
+            logger.debug(self.cfg_path)
+            logger.debug(self.am_model)
+            logger.debug(self.am_params)
        else:
            self.cfg_path = os.path.abspath(cfg_path)
            self.am_model = os.path.abspath(am_model)

@@ -236,16 +236,16 @@ class PaddleASRConnectionHandler(ASRServerExecutor):
        if self._check(
                io.BytesIO(audio_data), self.asr_engine.config.sample_rate,
                self.asr_engine.config.force_yes):
-            logger.info("start running asr engine")
+            logger.debug("start running asr engine")
            self.preprocess(self.asr_engine.config.model_type,
                            io.BytesIO(audio_data))
            st = time.time()
            self.infer(self.asr_engine.config.model_type)
            infer_time = time.time() - st
            self.output = self.postprocess()  # Retrieve result of asr.
-            logger.info("end inferring asr engine")
+            logger.debug("end inferring asr engine")
        else:
-            logger.info("file check failed!")
+            logger.error("file check failed!")
            self.output = None

        logger.info("inference time: {}".format(infer_time))
paddlespeech/server/engine/asr/python/asr_engine.py

@@ -104,7 +104,7 @@ class PaddleASRConnectionHandler(ASRServerExecutor):
        if self._check(
                io.BytesIO(audio_data), self.asr_engine.config.sample_rate,
                self.asr_engine.config.force_yes):
-            logger.info("start run asr engine")
+            logger.debug("start run asr engine")
            self.preprocess(self.asr_engine.config.model,
                            io.BytesIO(audio_data))
            st = time.time()

@@ -112,7 +112,7 @@ class PaddleASRConnectionHandler(ASRServerExecutor):
            infer_time = time.time() - st
            self.output = self.postprocess()  # Retrieve result of asr.
        else:
-            logger.info("file check failed!")
+            logger.error("file check failed!")
            self.output = None

        logger.info("inference time: {}".format(infer_time))
paddlespeech/server/engine/cls/paddleinference/cls_engine.py

@@ -67,22 +67,22 @@ class CLSServerExecutor(CLSExecutor):
            self.params_path = os.path.abspath(params_path)
            self.label_file = os.path.abspath(label_file)

-        logger.info(self.cfg_path)
-        logger.info(self.model_path)
-        logger.info(self.params_path)
-        logger.info(self.label_file)
+        logger.debug(self.cfg_path)
+        logger.debug(self.model_path)
+        logger.debug(self.params_path)
+        logger.debug(self.label_file)

        # config
        with open(self.cfg_path, 'r') as f:
            self._conf = yaml.safe_load(f)
-        logger.info("Read cfg file successfully.")
+        logger.debug("Read cfg file successfully.")

        # labels
        self._label_list = []
        with open(self.label_file, 'r') as f:
            for line in f:
                self._label_list.append(line.strip())
-        logger.info("Read label file successfully.")
+        logger.debug("Read label file successfully.")

        # Create predictor
        self.predictor_conf = predictor_conf

@@ -90,7 +90,7 @@ class CLSServerExecutor(CLSExecutor):
            model_file=self.model_path,
            params_file=self.params_path,
            predictor_conf=self.predictor_conf)
-        logger.info("Create predictor successfully.")
+        logger.debug("Create predictor successfully.")

    @paddle.no_grad()
    def infer(self):

@@ -148,7 +148,8 @@ class CLSEngine(BaseEngine):
            logger.error(e)
            return False

-        logger.info("Initialize CLS server engine successfully.")
+        logger.info("Initialize CLS server engine successfully on device: %s." %
+                    (self.device))
        return True

@@ -160,7 +161,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
            cls_engine (CLSEngine): The CLS engine
        """
        super().__init__()
-        logger.info(
+        logger.debug(
            "Create PaddleCLSConnectionHandler to process the cls request")

        self._inputs = OrderedDict()

@@ -183,7 +184,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
        self.infer()
        infer_time = time.time() - st

-        logger.info("inference time: {}".format(infer_time))
+        logger.debug("inference time: {}".format(infer_time))
        logger.info("cls engine type: inference")

    def postprocess(self, topk: int):
paddlespeech/server/engine/cls/python/cls_engine.py

@@ -88,7 +88,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
            cls_engine (CLSEngine): The CLS engine
        """
        super().__init__()
-        logger.info(
+        logger.debug(
            "Create PaddleCLSConnectionHandler to process the cls request")

        self._inputs = OrderedDict()

@@ -110,7 +110,7 @@ class PaddleCLSConnectionHandler(CLSServerExecutor):
        self.infer()
        infer_time = time.time() - st

-        logger.info("inference time: {}".format(infer_time))
+        logger.debug("inference time: {}".format(infer_time))
        logger.info("cls engine type: python")

    def postprocess(self, topk: int):
paddlespeech/server/engine/engine_warmup.py
@@ -45,7 +45,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
         logger.error("Please check tte engine type.")

     try:
-        logger.info("Start to warm up tts engine.")
+        logger.debug("Start to warm up tts engine.")
         for i in range(warm_up_time):
             connection_handler = PaddleTTSConnectionHandler(tts_engine)
             if flag_online:
@@ -53,7 +53,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
                         text=sentence,
                         lang=tts_engine.lang,
                         am=tts_engine.config.am):
-                    logger.info(
+                    logger.debug(
                         f"The first response time of the {i} warm up: {connection_handler.first_response_time} s"
                     )
                     break
@@ -62,7 +62,7 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool:
                 st = time.time()
                 connection_handler.infer(text=sentence)
                 et = time.time()
-                logger.info(
+                logger.debug(
                     f"The response time of the {i} warm up: {et - st} s")
     except Exception as e:
         logger.error("Failed to warm up on tts engine.")
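The warm-up timings above are demoted from info to debug, so they only show up when the server log level is lowered. A minimal sketch of that behaviour using Python's standard logging module (the logger name and the timing loop here are illustrative; paddlespeech uses its own logger wrapper, but the level semantics are the same):

import logging
import time

# Illustrative logger; paddlespeech wraps logging with its own helper.
logger = logging.getLogger("warmup_demo")
logging.basicConfig(level=logging.INFO)  # typical default server verbosity

def fake_warm_up(warm_up_time: int = 3) -> None:
    for i in range(warm_up_time):
        st = time.time()
        time.sleep(0.01)  # stand-in for connection_handler.infer(text=sentence)
        et = time.time()
        # Per-iteration detail: hidden unless the level is lowered to DEBUG.
        logger.debug(f"The response time of the {i} warm up: {et - st} s")
    # The summary stays at INFO, so it is still printed by default.
    logger.info("Warm up finished.")

fake_warm_up()
# To see the per-iteration lines again:
# logging.getLogger("warmup_demo").setLevel(logging.DEBUG)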
paddlespeech/server/engine/text/python/text_engine.py
@@ -28,7 +28,7 @@ class PaddleTextConnectionHandler:
             text_engine (TextEngine): The Text engine
         """
         super().__init__()
-        logger.info(
+        logger.debug(
             "Create PaddleTextConnectionHandler to process the text request")
         self.text_engine = text_engine
         self.task = self.text_engine.executor.task
@@ -130,7 +130,7 @@ class TextEngine(BaseEngine):
         """The Text Engine
         """
         super(TextEngine, self).__init__()
-        logger.info("Create the TextEngine Instance")
+        logger.debug("Create the TextEngine Instance")

     def init(self, config: dict):
         """Init the Text Engine
@@ -141,7 +141,7 @@ class TextEngine(BaseEngine):
         Returns:
             bool: The engine instance flag
         """
-        logger.info("Init the text engine")
+        logger.debug("Init the text engine")
         try:
             self.config = config
             if self.config.device:
@@ -150,7 +150,7 @@ class TextEngine(BaseEngine):
                 self.device = paddle.get_device()
             paddle.set_device(self.device)
-            logger.info(f"Text Engine set the device: {self.device}")
+            logger.debug(f"Text Engine set the device: {self.device}")
         except BaseException as e:
             logger.error(
                 "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
@@ -168,5 +168,6 @@ class TextEngine(BaseEngine):
             ckpt_path=config.ckpt_path,
             vocab_file=config.vocab_file)

-        logger.info("Init the text engine successfully")
+        logger.info("Initialize Text server engine successfully on device: %s." %
+                    (self.device))
         return True
paddlespeech/server/engine/tts/online/onnx/tts_engine.py
@@ -62,7 +62,7 @@ class TTSServerExecutor(TTSExecutor):
                 (hasattr(self, 'am_encoder_infer_sess') and
                  hasattr(self, 'am_decoder_sess') and hasattr(
                      self, 'am_postnet_sess'))) and hasattr(self, 'voc_inference'):
-            logger.info('Models had been initialized.')
+            logger.debug('Models had been initialized.')
             return
         # am
         am_tag = am + '-' + lang
@@ -85,8 +85,7 @@ class TTSServerExecutor(TTSExecutor):
         else:
             self.am_ckpt = os.path.abspath(am_ckpt[0])
             self.phones_dict = os.path.abspath(phones_dict)
             self.am_res_path = os.path.dirname(
                 os.path.abspath(am_ckpt))
         # create am sess
         self.am_sess = get_sess(self.am_ckpt, am_sess_conf)
@@ -119,8 +118,7 @@ class TTSServerExecutor(TTSExecutor):
             self.am_postnet = os.path.abspath(am_ckpt[2])
             self.phones_dict = os.path.abspath(phones_dict)
             self.am_stat = os.path.abspath(am_stat)
             self.am_res_path = os.path.dirname(
                 os.path.abspath(am_ckpt[0]))
         # create am sess
         self.am_encoder_infer_sess = get_sess(self.am_encoder_infer,
@@ -130,13 +128,13 @@ class TTSServerExecutor(TTSExecutor):
         self.am_mu, self.am_std = np.load(self.am_stat)
-        logger.info(f"self.phones_dict: {self.phones_dict}")
-        logger.info(f"am model dir: {self.am_res_path}")
-        logger.info("Create am sess successfully.")
+        logger.debug(f"self.phones_dict: {self.phones_dict}")
+        logger.debug(f"am model dir: {self.am_res_path}")
+        logger.debug("Create am sess successfully.")

         # voc model info
         voc_tag = voc + '-' + lang
         if voc_ckpt is None:
             self.task_resource.set_task_model(
                 model_tag=voc_tag,
@@ -149,16 +147,16 @@ class TTSServerExecutor(TTSExecutor):
         else:
             self.voc_ckpt = os.path.abspath(voc_ckpt)
             self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt))
-        logger.info(self.voc_res_path)
+        logger.debug(self.voc_res_path)
         # create voc sess
         self.voc_sess = get_sess(self.voc_ckpt, voc_sess_conf)
-        logger.info("Create voc sess successfully.")
+        logger.debug("Create voc sess successfully.")

         with open(self.phones_dict, "r") as f:
             phn_id = [line.strip().split() for line in f.readlines()]
         self.vocab_size = len(phn_id)
-        logger.info(f"vocab_size: {self.vocab_size}")
+        logger.debug(f"vocab_size: {self.vocab_size}")

         # frontend
         self.tones_dict = None
@@ -169,7 +167,7 @@ class TTSServerExecutor(TTSExecutor):
         elif lang == 'en':
             self.frontend = English(phone_vocab_path=self.phones_dict)
-        logger.info("frontend done!")
+        logger.debug("frontend done!")


 class TTSEngine(BaseEngine):
@@ -267,7 +265,7 @@ class PaddleTTSConnectionHandler:
             tts_engine (TTSEngine): The TTS engine
         """
         super().__init__()
-        logger.info(
+        logger.debug(
             "Create PaddleTTSConnectionHandler to process the tts request")
         self.tts_engine = tts_engine
paddlespeech/server/engine/tts/online/python/tts_engine.py
@@ -102,7 +102,7 @@ class TTSServerExecutor(TTSExecutor):
         Init model and other resources from a specific path.
         """
         if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'):
-            logger.info('Models had been initialized.')
+            logger.debug('Models had been initialized.')
             return
         # am model info
         if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
@@ -128,17 +128,15 @@ class TTSServerExecutor(TTSExecutor):
             # must have phones_dict in acoustic
             self.phones_dict = os.path.join(
                 self.am_res_path, self.task_resource.res_dict['phones_dict'])
-            print("self.phones_dict:", self.phones_dict)
-            logger.info(self.am_res_path)
-            logger.info(self.am_config)
-            logger.info(self.am_ckpt)
+            logger.debug(self.am_res_path)
+            logger.debug(self.am_config)
+            logger.debug(self.am_ckpt)
         else:
             self.am_config = os.path.abspath(am_config)
             self.am_ckpt = os.path.abspath(am_ckpt)
             self.am_stat = os.path.abspath(am_stat)
             self.phones_dict = os.path.abspath(phones_dict)
             self.am_res_path = os.path.dirname(os.path.abspath(self.am_config))
-            print("self.phones_dict:", self.phones_dict)

         self.tones_dict = None
         self.speaker_dict = None
@@ -165,9 +163,9 @@ class TTSServerExecutor(TTSExecutor):
             self.voc_stat = os.path.join(
                 self.voc_res_path,
                 self.task_resource.voc_res_dict['speech_stats'])
-            logger.info(self.voc_res_path)
-            logger.info(self.voc_config)
-            logger.info(self.voc_ckpt)
+            logger.debug(self.voc_res_path)
+            logger.debug(self.voc_config)
+            logger.debug(self.voc_ckpt)
         else:
             self.voc_config = os.path.abspath(voc_config)
             self.voc_ckpt = os.path.abspath(voc_ckpt)
@@ -184,7 +182,6 @@ class TTSServerExecutor(TTSExecutor):
         with open(self.phones_dict, "r") as f:
             phn_id = [line.strip().split() for line in f.readlines()]
         self.vocab_size = len(phn_id)
-        print("vocab_size:", self.vocab_size)

         # frontend
         if lang == 'zh':
@@ -194,7 +191,6 @@ class TTSServerExecutor(TTSExecutor):
         elif lang == 'en':
             self.frontend = English(phone_vocab_path=self.phones_dict)
-        print("frontend done!")

         # am infer info
         self.am_name = am[:am.rindex('_')]
@@ -209,7 +205,6 @@ class TTSServerExecutor(TTSExecutor):
             self.am_name + '_inference')
         self.am_inference = am_inference_class(am_normalizer, am)
         self.am_inference.eval()
-        print("acoustic model done!")

         # voc infer info
         self.voc_name = voc[:voc.rindex('_')]
@@ -220,7 +215,6 @@ class TTSServerExecutor(TTSExecutor):
             '_inference')
         self.voc_inference = voc_inference_class(voc_normalizer, voc)
         self.voc_inference.eval()
-        print("voc done!")


 class TTSEngine(BaseEngine):
@@ -309,7 +303,7 @@ class PaddleTTSConnectionHandler:
             tts_engine (TTSEngine): The TTS engine
         """
         super().__init__()
-        logger.info(
+        logger.debug(
             "Create PaddleTTSConnectionHandler to process the tts request")
         self.tts_engine = tts_engine
@@ -369,7 +363,7 @@ class PaddleTTSConnectionHandler:
                 text, merge_sentences=merge_sentences)
             phone_ids = input_ids["phone_ids"]
         else:
-            print("lang should in {'zh', 'en'}!")
+            logger.error("lang should in {'zh', 'en'}!")
         frontend_et = time.time()
         self.frontend_time = frontend_et - frontend_st
paddlespeech/server/engine/tts/paddleinference/tts_engine.py
@@ -65,7 +65,7 @@ class TTSServerExecutor(TTSExecutor):
         Init model and other resources from a specific path.
         """
         if hasattr(self, 'am_predictor') and hasattr(self, 'voc_predictor'):
-            logger.info('Models had been initialized.')
+            logger.debug('Models had been initialized.')
             return
         # am
         if am_model is None or am_params is None or phones_dict is None:
@@ -91,16 +91,16 @@ class TTSServerExecutor(TTSExecutor):
                 self.am_res_path, self.task_resource.res_dict['phones_dict'])
             self.am_sample_rate = self.task_resource.res_dict['sample_rate']
-            logger.info(self.am_res_path)
-            logger.info(self.am_model)
-            logger.info(self.am_params)
+            logger.debug(self.am_res_path)
+            logger.debug(self.am_model)
+            logger.debug(self.am_params)
         else:
             self.am_model = os.path.abspath(am_model)
             self.am_params = os.path.abspath(am_params)
             self.phones_dict = os.path.abspath(phones_dict)
             self.am_sample_rate = am_sample_rate
             self.am_res_path = os.path.dirname(os.path.abspath(self.am_model))
-        logger.info("self.phones_dict: {}".format(self.phones_dict))
+        logger.debug("self.phones_dict: {}".format(self.phones_dict))

         # for speedyspeech
         self.tones_dict = None
@@ -139,9 +139,9 @@ class TTSServerExecutor(TTSExecutor):
                 self.voc_res_path, self.task_resource.voc_res_dict['params'])
             self.voc_sample_rate = self.task_resource.voc_res_dict[
                 'sample_rate']
-            logger.info(self.voc_res_path)
-            logger.info(self.voc_model)
-            logger.info(self.voc_params)
+            logger.debug(self.voc_res_path)
+            logger.debug(self.voc_model)
+            logger.debug(self.voc_params)
         else:
             self.voc_model = os.path.abspath(voc_model)
             self.voc_params = os.path.abspath(voc_params)
@@ -156,21 +156,21 @@ class TTSServerExecutor(TTSExecutor):
         with open(self.phones_dict, "r") as f:
             phn_id = [line.strip().split() for line in f.readlines()]
         vocab_size = len(phn_id)
-        logger.info("vocab_size: {}".format(vocab_size))
+        logger.debug("vocab_size: {}".format(vocab_size))

         tone_size = None
         if self.tones_dict:
             with open(self.tones_dict, "r") as f:
                 tone_id = [line.strip().split() for line in f.readlines()]
             tone_size = len(tone_id)
-            logger.info("tone_size: {}".format(tone_size))
+            logger.debug("tone_size: {}".format(tone_size))

         spk_num = None
         if self.speaker_dict:
             with open(self.speaker_dict, 'rt') as f:
                 spk_id = [line.strip().split() for line in f.readlines()]
             spk_num = len(spk_id)
-            logger.info("spk_num: {}".format(spk_num))
+            logger.debug("spk_num: {}".format(spk_num))

         # frontend
         if lang == 'zh':
@@ -180,7 +180,7 @@ class TTSServerExecutor(TTSExecutor):
         elif lang == 'en':
             self.frontend = English(phone_vocab_path=self.phones_dict)
-        logger.info("frontend done!")
+        logger.debug("frontend done!")

         # Create am predictor
         self.am_predictor_conf = am_predictor_conf
@@ -188,7 +188,7 @@ class TTSServerExecutor(TTSExecutor):
             model_file=self.am_model,
             params_file=self.am_params,
             predictor_conf=self.am_predictor_conf)
-        logger.info("Create AM predictor successfully.")
+        logger.debug("Create AM predictor successfully.")

         # Create voc predictor
         self.voc_predictor_conf = voc_predictor_conf
@@ -196,7 +196,7 @@ class TTSServerExecutor(TTSExecutor):
             model_file=self.voc_model,
             params_file=self.voc_params,
             predictor_conf=self.voc_predictor_conf)
-        logger.info("Create Vocoder predictor successfully.")
+        logger.debug("Create Vocoder predictor successfully.")

     @paddle.no_grad()
     def infer(self,
@@ -328,7 +328,8 @@ class TTSEngine(BaseEngine):
             logger.error(e)
             return False

-        logger.info("Initialize TTS server engine successfully.")
+        logger.info("Initialize TTS server engine successfully on device: %s." %
+                    (self.device))
         return True
@@ -340,7 +341,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
             tts_engine (TTSEngine): The TTS engine
         """
         super().__init__()
-        logger.info(
+        logger.debug(
             "Create PaddleTTSConnectionHandler to process the tts request")
         self.tts_engine = tts_engine
@@ -378,23 +379,23 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
         if target_fs == 0 or target_fs > original_fs:
             target_fs = original_fs
             wav_tar_fs = wav
-            logger.info(
+            logger.debug(
                 "The sample rate of synthesized audio is the same as model, which is {}Hz".
                 format(original_fs))
         else:
             wav_tar_fs = librosa.resample(
                 np.squeeze(wav), original_fs, target_fs)
-            logger.info(
+            logger.debug(
                 "The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.".
                 format(original_fs, target_fs))
         # transform volume
         wav_vol = wav_tar_fs * volume
-        logger.info("Transform the volume of the audio successfully.")
+        logger.debug("Transform the volume of the audio successfully.")

         # transform speed
         try:  # windows not support soxbindings
             wav_speed = change_speed(wav_vol, speed, target_fs)
-            logger.info("Transform the speed of the audio successfully.")
+            logger.debug("Transform the speed of the audio successfully.")
         except ServerBaseException:
             raise ServerBaseException(
                 ErrorCode.SERVER_INTERNAL_ERR,
@@ -411,7 +412,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
         wavfile.write(buf, target_fs, wav_speed)
         base64_bytes = base64.b64encode(buf.read())
         wav_base64 = base64_bytes.decode('utf-8')
-        logger.info("Audio to string successfully.")
+        logger.debug("Audio to string successfully.")

         # save audio
         if audio_path is not None:
@@ -499,15 +500,15 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
             logger.error(e)
             sys.exit(-1)

-        logger.info("AM model: {}".format(self.config.am))
-        logger.info("Vocoder model: {}".format(self.config.voc))
-        logger.info("Language: {}".format(lang))
+        logger.debug("AM model: {}".format(self.config.am))
+        logger.debug("Vocoder model: {}".format(self.config.voc))
+        logger.debug("Language: {}".format(lang))
         logger.info("tts engine type: python")
         logger.info("audio duration: {}".format(duration))
-        logger.info("frontend inference time: {}".format(self.frontend_time))
-        logger.info("AM inference time: {}".format(self.am_time))
-        logger.info("Vocoder inference time: {}".format(self.voc_time))
+        logger.debug("frontend inference time: {}".format(self.frontend_time))
+        logger.debug("AM inference time: {}".format(self.am_time))
+        logger.debug("Vocoder inference time: {}".format(self.voc_time))
         logger.info("total inference time: {}".format(infer_time))
         logger.info(
             "postprocess (change speed, volume, target sample rate) time: {}".
@@ -515,6 +516,6 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
         logger.info("total generate audio time: {}".format(infer_time +
                                                            postprocess_time))
         logger.info("RTF: {}".format(rtf))
-        logger.info("device: {}".format(self.tts_engine.device))
+        logger.debug("device: {}".format(self.tts_engine.device))

         return lang, target_sample_rate, duration, wav_base64
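The hunks above only change the log level around the audio postprocess chain: resample to the requested rate, scale the volume, change the speed with soxbindings, then base64-encode the in-memory WAV for the HTTP response. A simplified sketch of that chain which leaves out the sox-based change_speed step; the input array and sample rates are made-up values for illustration:

import base64
import io

import librosa
import numpy as np
from scipy.io import wavfile

def postprocess_sketch(wav: np.ndarray, original_fs: int, target_fs: int,
                       volume: float) -> str:
    # 1. resample only when a lower target sample rate was requested
    if target_fs == 0 or target_fs > original_fs:
        target_fs = original_fs
        wav_tar_fs = wav
    else:
        wav_tar_fs = librosa.resample(
            np.squeeze(wav), orig_sr=original_fs, target_sr=target_fs)

    # 2. scale the volume
    wav_vol = wav_tar_fs * volume

    # 3. (speed change via soxbindings omitted in this sketch)

    # 4. write a WAV into memory and base64-encode it for the response body
    buf = io.BytesIO()
    wavfile.write(buf, target_fs, wav_vol.astype(np.float32))
    buf.seek(0)
    return base64.b64encode(buf.read()).decode('utf-8')

# toy usage: 0.1 s of silence at 24 kHz, downsampled to 16 kHz
print(len(postprocess_sketch(np.zeros(2400, dtype=np.float32), 24000, 16000, 1.0)))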
paddlespeech/server/engine/tts/python/tts_engine.py
@@ -105,7 +105,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
             tts_engine (TTSEngine): The TTS engine
         """
         super().__init__()
-        logger.info(
+        logger.debug(
             "Create PaddleTTSConnectionHandler to process the tts request")
         self.tts_engine = tts_engine
@@ -143,23 +143,23 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
         if target_fs == 0 or target_fs > original_fs:
             target_fs = original_fs
             wav_tar_fs = wav
-            logger.info(
+            logger.debug(
                 "The sample rate of synthesized audio is the same as model, which is {}Hz".
                 format(original_fs))
         else:
             wav_tar_fs = librosa.resample(
                 np.squeeze(wav), original_fs, target_fs)
-            logger.info(
+            logger.debug(
                 "The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.".
                 format(original_fs, target_fs))
         # transform volume
         wav_vol = wav_tar_fs * volume
-        logger.info("Transform the volume of the audio successfully.")
+        logger.debug("Transform the volume of the audio successfully.")

         # transform speed
         try:  # windows not support soxbindings
             wav_speed = change_speed(wav_vol, speed, target_fs)
-            logger.info("Transform the speed of the audio successfully.")
+            logger.debug("Transform the speed of the audio successfully.")
         except ServerBaseException:
             raise ServerBaseException(
                 ErrorCode.SERVER_INTERNAL_ERR,
@@ -176,7 +176,7 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
         wavfile.write(buf, target_fs, wav_speed)
         base64_bytes = base64.b64encode(buf.read())
         wav_base64 = base64_bytes.decode('utf-8')
-        logger.info("Audio to string successfully.")
+        logger.debug("Audio to string successfully.")

         # save audio
         if audio_path is not None:
@@ -264,15 +264,15 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
             logger.error(e)
             sys.exit(-1)

-        logger.info("AM model: {}".format(self.config.am))
-        logger.info("Vocoder model: {}".format(self.config.voc))
-        logger.info("Language: {}".format(lang))
+        logger.debug("AM model: {}".format(self.config.am))
+        logger.debug("Vocoder model: {}".format(self.config.voc))
+        logger.debug("Language: {}".format(lang))
         logger.info("tts engine type: python")
         logger.info("audio duration: {}".format(duration))
-        logger.info("frontend inference time: {}".format(self.frontend_time))
-        logger.info("AM inference time: {}".format(self.am_time))
-        logger.info("Vocoder inference time: {}".format(self.voc_time))
+        logger.debug("frontend inference time: {}".format(self.frontend_time))
+        logger.debug("AM inference time: {}".format(self.am_time))
+        logger.debug("Vocoder inference time: {}".format(self.voc_time))
         logger.info("total inference time: {}".format(infer_time))
         logger.info(
             "postprocess (change speed, volume, target sample rate) time: {}".
@@ -280,6 +280,6 @@ class PaddleTTSConnectionHandler(TTSServerExecutor):
         logger.info("total generate audio time: {}".format(infer_time +
                                                            postprocess_time))
         logger.info("RTF: {}".format(rtf))
-        logger.info("device: {}".format(self.tts_engine.device))
+        logger.debug("device: {}".format(self.tts_engine.device))

         return lang, target_sample_rate, duration, wav_base64
paddlespeech/server/engine/vector/python/vector_engine.py
@@ -33,7 +33,7 @@ class PaddleVectorConnectionHandler:
             vector_engine (VectorEngine): The Vector engine
         """
         super().__init__()
-        logger.info(
+        logger.debug(
             "Create PaddleVectorConnectionHandler to process the vector request")
         self.vector_engine = vector_engine
         self.executor = self.vector_engine.executor
@@ -54,7 +54,7 @@ class PaddleVectorConnectionHandler:
         Returns:
             str: the punctuation text
         """
-        logger.info(
+        logger.debug(
             f"start to extract the do vector {self.task} from the http request")
         if self.task == "spk" and task == "spk":
             embedding = self.extract_audio_embedding(audio_data)
@@ -81,17 +81,17 @@ class PaddleVectorConnectionHandler:
         Returns:
             float: the score between enroll and test audio
         """
-        logger.info("start to extract the enroll audio embedding")
+        logger.debug("start to extract the enroll audio embedding")
         enroll_emb = self.extract_audio_embedding(enroll_audio)

-        logger.info("start to extract the test audio embedding")
+        logger.debug("start to extract the test audio embedding")
         test_emb = self.extract_audio_embedding(test_audio)

-        logger.info(
+        logger.debug(
             "start to get the score between the enroll and test embedding")
         score = self.executor.get_embeddings_score(enroll_emb, test_emb)
-        logger.info(f"get the enroll vs test score: {score}")
+        logger.debug(f"get the enroll vs test score: {score}")

         return score

     @paddle.no_grad()
@@ -106,11 +106,12 @@ class PaddleVectorConnectionHandler:
         # because the soundfile will change the io.BytesIO(audio) to the end
         # thus we should convert the base64 string to io.BytesIO when we need the audio data
         if not self.executor._check(io.BytesIO(audio), sample_rate):
-            logger.info("check the audio sample rate occurs error")
+            logger.debug("check the audio sample rate occurs error")
             return np.array([0.0])

         waveform, sr = load_audio(io.BytesIO(audio))
-        logger.info(f"load the audio sample points, shape is: {waveform.shape}")
+        logger.debug(f"load the audio sample points, shape is: {waveform.shape}")

         # stage 2: get the audio feat
         # Note: Now we only support fbank feature
@@ -121,9 +122,9 @@ class PaddleVectorConnectionHandler:
                 n_mels=self.config.n_mels,
                 window_size=self.config.window_size,
                 hop_length=self.config.hop_size)
-            logger.info(f"extract the audio feats, shape is: {feats.shape}")
+            logger.debug(f"extract the audio feats, shape is: {feats.shape}")
         except Exception as e:
-            logger.info(f"feats occurs exception {e}")
+            logger.error(f"feats occurs exception {e}")
             sys.exit(-1)

         feats = paddle.to_tensor(feats).unsqueeze(0)
@@ -159,7 +160,7 @@ class VectorEngine(BaseEngine):
         """The Vector Engine
         """
         super(VectorEngine, self).__init__()
-        logger.info("Create the VectorEngine Instance")
+        logger.debug("Create the VectorEngine Instance")

     def init(self, config: dict):
         """Init the Vector Engine
@@ -170,7 +171,7 @@ class VectorEngine(BaseEngine):
         Returns:
             bool: The engine instance flag
         """
-        logger.info("Init the vector engine")
+        logger.debug("Init the vector engine")
         try:
             self.config = config
             if self.config.device:
@@ -179,7 +180,7 @@ class VectorEngine(BaseEngine):
                 self.device = paddle.get_device()
             paddle.set_device(self.device)
-            logger.info(f"Vector Engine set the device: {self.device}")
+            logger.debug(f"Vector Engine set the device: {self.device}")
         except BaseException as e:
             logger.error(
                 "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
@@ -196,5 +197,7 @@ class VectorEngine(BaseEngine):
             ckpt_path=config.ckpt_path,
             task=config.task)

-        logger.info("Init the Vector engine successfully")
+        logger.info(
+            "Initialize Vector server engine successfully on device: %s." %
+            (self.device))
         return True
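The scoring path itself is unchanged by this commit: the handler extracts one embedding for the enroll audio and one for the test audio, then asks the executor for a similarity score. A minimal stand-in for get_embeddings_score using plain cosine similarity (the real executor may normalize or calibrate the score differently):

import numpy as np

def cosine_score(enroll_emb: np.ndarray, test_emb: np.ndarray) -> float:
    """Cosine similarity in [-1, 1]; higher means more likely the same speaker."""
    enroll = enroll_emb / (np.linalg.norm(enroll_emb) + 1e-12)
    test = test_emb / (np.linalg.norm(test_emb) + 1e-12)
    return float(np.dot(enroll, test))

# toy usage with random 192-dim embeddings
rng = np.random.default_rng(0)
emb_a, emb_b = rng.standard_normal(192), rng.standard_normal(192)
print(cosine_score(emb_a, emb_a))  # ~1.0 for identical embeddings
print(cosine_score(emb_a, emb_b))  # close to 0 for unrelated vectors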
paddlespeech/server/utils/audio_handler.py
@@ -138,7 +138,7 @@ class ASRWsAudioHandler:
         Returns:
             str: the final asr result
         """
-        logging.info("send a message to the server")
+        logging.debug("send a message to the server")

         if self.url is None:
             logger.error("No asr server, please input valid ip and port")
@@ -160,7 +160,7 @@ class ASRWsAudioHandler:
                 separators=(',', ': '))
             await ws.send(audio_info)
             msg = await ws.recv()
-            logger.info("client receive msg={}".format(msg))
+            logger.debug("client receive msg={}".format(msg))

             # 3. send chunk audio data to engine
             for chunk_data in self.read_wave(wavfile_path):
@@ -170,7 +170,7 @@ class ASRWsAudioHandler:
                 if self.punc_server and len(msg["result"]) > 0:
                     msg["result"] = self.punc_server.run(msg["result"])
-                logger.info("client receive msg={}".format(msg))
+                logger.debug("client receive msg={}".format(msg))

             # 4. we must send finished signal to the server
             audio_info = json.dumps(
@@ -310,7 +310,7 @@ class TTSWsHandler:
         start_request = json.dumps({"task": "tts", "signal": "start"})
         await ws.send(start_request)
         msg = await ws.recv()
-        logger.info(f"client receive msg={msg}")
+        logger.debug(f"client receive msg={msg}")
         msg = json.loads(msg)
         session = msg["session"]
@@ -319,7 +319,7 @@ class TTSWsHandler:
         request = json.dumps({"text": text_base64})
         st = time.time()
         await ws.send(request)
-        logging.info("send a message to the server")
+        logging.debug("send a message to the server")

         # 4. Process the received response
         message = await ws.recv()
@@ -543,7 +543,6 @@ class VectorHttpHandler:
             "sample_rate": sample_rate,
         }

-        logger.info(self.url)
         res = requests.post(url=self.url, data=json.dumps(data))

         return res.json()
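Independently of the log-level changes, the websocket handlers above all follow the same request flow: connect, send a JSON start message, then exchange messages with the server. A bare-bones sketch of that handshake with the websockets package; the URL, port, and message fields are illustrative placeholders based on what is visible in the diff, not a guaranteed server contract:

import asyncio
import json

import websockets

async def tts_handshake(url: str = "ws://127.0.0.1:8092/paddlespeech/tts/streaming"):
    async with websockets.connect(url) as ws:
        # 1. tell the server a TTS session is starting
        await ws.send(json.dumps({"task": "tts", "signal": "start"}))
        msg = json.loads(await ws.recv())
        print("session:", msg.get("session"))
        # 2. ...send text requests and collect audio chunks here...
        # 3. a matching "end" message would normally close the session

# asyncio.run(tts_handshake())  # requires a running streaming TTS server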
paddlespeech/server/utils/audio_process.py
@@ -169,7 +169,7 @@ def save_audio(bytes_data, audio_path, sample_rate: int=24000) -> bool:
             sample_rate=sample_rate)
         os.remove("./tmp.pcm")
     else:
-        print("Only supports saved audio format is pcm or wav")
+        logger.error("Only supports saved audio format is pcm or wav")
         return False

     return True
paddlespeech/server/utils/onnx_infer.py
@@ -20,7 +20,7 @@ from paddlespeech.cli.log import logger


 def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
-    logger.info(f"ort sessconf: {sess_conf}")
+    logger.debug(f"ort sessconf: {sess_conf}")
     sess_options = ort.SessionOptions()
     sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
     if sess_conf.get('graph_optimization_level', 99) == 0:
@@ -34,7 +34,7 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
     # fastspeech2/mb_melgan can't use trt now!
     if sess_conf.get("use_trt", 0):
         providers = ['TensorrtExecutionProvider']
-    logger.info(f"ort providers: {providers}")
+    logger.debug(f"ort providers: {providers}")

     if 'cpu_threads' in sess_conf:
         sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0)
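get_sess builds an onnxruntime session from a small config dict (graph optimization level, CPU thread count, optional TensorRT provider). A reduced sketch of the same idea that only uses options plain onnxruntime exposes; the model path is a placeholder and the config keys mirror the ones visible in the diff:

import onnxruntime as ort

def make_session(model_path: str, sess_conf: dict) -> ort.InferenceSession:
    sess_options = ort.SessionOptions()
    # graph optimization: 0 disables it, anything else keeps "enable all"
    if sess_conf.get('graph_optimization_level', 99) == 0:
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    else:
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    # limit intra-op threads when requested
    if 'cpu_threads' in sess_conf:
        sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0)
    providers = ['CPUExecutionProvider']
    return ort.InferenceSession(
        model_path, sess_options=sess_options, providers=providers)

# sess = make_session("model.onnx", {"cpu_threads": 2})  # "model.onnx" is a placeholder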
paddlespeech/server/utils/util.py
@@ -13,6 +13,8 @@
 import base64
 import math

+from paddlespeech.cli.log import logger
+

 def wav2base64(wav_file: str):
     """
@@ -61,7 +63,7 @@ def get_chunks(data, block_size, pad_size, step):
     elif step == "voc":
         data_len = data.shape[0]
     else:
-        print("Please set correct type to get chunks, am or voc")
+        logger.error("Please set correct type to get chunks, am or voc")

     chunks = []
     n = math.ceil(data_len / block_size)
@@ -73,7 +75,7 @@ def get_chunks(data, block_size, pad_size, step):
     elif step == "voc":
         chunks.append(data[start:end, :])
     else:
-        print("Please set correct type to get chunks, am or voc")
+        logger.error("Please set correct type to get chunks, am or voc")

     return chunks
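get_chunks splits a long feature or waveform array into fixed-size blocks with padding on both sides so streaming synthesis can overlap the block edges. The exact slicing is not shown in these hunks; the following is a simplified illustration of block/pad chunking along the first axis, a plausible reconstruction rather than the library's implementation (it also drops the am/voc step argument):

import math

import numpy as np

def get_chunks_sketch(data: np.ndarray, block_size: int, pad_size: int):
    """Split data along axis 0 into blocks of block_size, padded by pad_size on each side."""
    data_len = data.shape[0]
    chunks = []
    n = math.ceil(data_len / block_size)
    for i in range(n):
        start = max(0, i * block_size - pad_size)
        end = min((i + 1) * block_size + pad_size, data_len)
        chunks.append(data[start:end])
    return chunks

# toy usage: 10 frames of 4-dim features, blocks of 4 frames padded by 1
feats = np.arange(40).reshape(10, 4)
for c in get_chunks_sketch(feats, block_size=4, pad_size=1):
    print(c.shape)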