Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
40dde22f
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
40dde22f
编写于
4月 19, 2022
作者:
L
lym0302
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
code format, test=doc
上级
00a6236f
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
27 addition
and
17 deletion
+27
-17
paddlespeech/server/engine/tts/online/tts_engine.py
paddlespeech/server/engine/tts/online/tts_engine.py
+27
-17
未找到文件。
paddlespeech/server/engine/tts/online/tts_engine.py
浏览文件 @
40dde22f
...
...
@@ -127,33 +127,40 @@ class TTSServerExecutor(TTSExecutor):
self
.
voc_block
=
voc_block
self
.
voc_pad
=
voc_pad
def
get_model_info
(
self
,
step
,
model_name
,
ckpt
,
stat
):
def
get_model_info
(
self
,
field
:
str
,
model_name
:
str
,
ckpt
:
Optional
[
os
.
PathLike
],
stat
:
Optional
[
os
.
PathLike
]):
"""get model information
Args:
step (string
): am or voc
model_name (str
ing): model type, support fastspeech2, higigan, mb_melgan
ckpt (
string
): ckpt file
stat (
string
): stat file, including mean and standard deviation
field (str
): am or voc
model_name (str
): model type, support fastspeech2, higigan, mb_melgan
ckpt (
Optional[os.PathLike]
): ckpt file
stat (
Optional[os.PathLike]
): stat file, including mean and standard deviation
Returns:
model, model_mu, model_std
[module]: model module
[Tensor]: mean
[Tensor]: standard deviation
"""
model_class
=
dynamic_import
(
model_name
,
model_alias
)
if
step
==
"am"
:
if
field
==
"am"
:
odim
=
self
.
am_config
.
n_mels
model
=
model_class
(
idim
=
self
.
vocab_size
,
odim
=
odim
,
**
self
.
am_config
[
"model"
])
model
.
set_state_dict
(
paddle
.
load
(
ckpt
)[
"main_params"
])
elif
step
==
"voc"
:
elif
field
==
"voc"
:
model
=
model_class
(
**
self
.
voc_config
[
"generator_params"
])
model
.
set_state_dict
(
paddle
.
load
(
ckpt
)[
"generator_params"
])
model
.
remove_weight_norm
()
else
:
logger
.
error
(
"Please set correct
step
, am or voc"
)
logger
.
error
(
"Please set correct
field
, am or voc"
)
model
.
eval
()
model_mu
,
model_std
=
np
.
load
(
stat
)
...
...
@@ -346,7 +353,8 @@ class TTSServerExecutor(TTSExecutor):
voc_block
=
self
.
voc_block
voc_pad
=
self
.
voc_pad
voc_upsample
=
self
.
voc_config
.
n_shift
flag
=
1
# first_flag 用于标记首包
first_flag
=
1
get_tone_ids
=
False
merge_sentences
=
False
...
...
@@ -376,7 +384,7 @@ class TTSServerExecutor(TTSExecutor):
if
am
==
"fastspeech2_csmsc"
:
# am
mel
=
self
.
am_inference
(
part_phone_ids
)
if
flag
==
1
:
if
f
irst_f
lag
==
1
:
first_am_et
=
time
.
time
()
self
.
first_am_infer
=
first_am_et
-
frontend_et
...
...
@@ -388,11 +396,11 @@ class TTSServerExecutor(TTSExecutor):
sub_wav
=
self
.
voc_inference
(
mel_chunk
)
sub_wav
=
self
.
depadding
(
sub_wav
,
voc_chunk_num
,
i
,
voc_block
,
voc_pad
,
voc_upsample
)
if
flag
==
1
:
if
f
irst_f
lag
==
1
:
first_voc_et
=
time
.
time
()
self
.
first_voc_infer
=
first_voc_et
-
first_am_et
self
.
first_response_time
=
first_voc_et
-
frontend_st
flag
=
0
f
irst_f
lag
=
0
yield
sub_wav
...
...
@@ -427,9 +435,10 @@ class TTSServerExecutor(TTSExecutor):
(
mel_streaming
,
sub_mel
),
axis
=
0
)
# streaming voc
# 当流式AM推理的mel帧数大于流式voc推理的chunk size,开始进行流式voc 推理
while
(
mel_streaming
.
shape
[
0
]
>=
end
and
voc_chunk_id
<
voc_chunk_num
):
if
flag
==
1
:
if
f
irst_f
lag
==
1
:
first_am_et
=
time
.
time
()
self
.
first_am_infer
=
first_am_et
-
frontend_et
voc_chunk
=
mel_streaming
[
start
:
end
,
:]
...
...
@@ -439,11 +448,11 @@ class TTSServerExecutor(TTSExecutor):
sub_wav
=
self
.
depadding
(
sub_wav
,
voc_chunk_num
,
voc_chunk_id
,
voc_block
,
voc_pad
,
voc_upsample
)
if
flag
==
1
:
if
f
irst_f
lag
==
1
:
first_voc_et
=
time
.
time
()
self
.
first_voc_infer
=
first_voc_et
-
first_am_et
self
.
first_response_time
=
first_voc_et
-
frontend_st
flag
=
0
f
irst_f
lag
=
0
yield
sub_wav
...
...
@@ -470,7 +479,8 @@ class TTSEngine(BaseEngine):
def
__init__
(
self
,
name
=
None
):
"""Initialize TTS server engine
"""
super
(
TTSEngine
,
self
).
__init__
()
#super(TTSEngine, self).__init__()
super
().
__init__
()
def
init
(
self
,
config
:
dict
)
->
bool
:
self
.
config
=
config
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录