Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
adc7c9b4
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
adc7c9b4
编写于
6月 29, 2022
作者:
K
KP
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix unnecessary download present in issue #2067.
上级
a0d1888c
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
22 addition
and
6 deletion
+22
-6
paddlespeech/cli/tts/infer.py
paddlespeech/cli/tts/infer.py
+15
-2
paddlespeech/resource/resource.py
paddlespeech/resource/resource.py
+7
-4
未找到文件。
paddlespeech/cli/tts/infer.py
浏览文件 @
adc7c9b4
...
@@ -175,14 +175,21 @@ class TTSExecutor(BaseExecutor):
...
@@ -175,14 +175,21 @@ class TTSExecutor(BaseExecutor):
if
hasattr
(
self
,
'am_inference'
)
and
hasattr
(
self
,
'voc_inference'
):
if
hasattr
(
self
,
'am_inference'
)
and
hasattr
(
self
,
'voc_inference'
):
logger
.
info
(
'Models had been initialized.'
)
logger
.
info
(
'Models had been initialized.'
)
return
return
# am
# am
if
am_ckpt
is
None
or
am_config
is
None
or
am_stat
is
None
or
phones_dict
is
None
:
use_pretrained_am
=
True
else
:
use_pretrained_am
=
False
am_tag
=
am
+
'-'
+
lang
am_tag
=
am
+
'-'
+
lang
self
.
task_resource
.
set_task_model
(
self
.
task_resource
.
set_task_model
(
model_tag
=
am_tag
,
model_tag
=
am_tag
,
model_type
=
0
,
# am
model_type
=
0
,
# am
skip_download
=
not
use_pretrained_am
,
version
=
None
,
# default version
version
=
None
,
# default version
)
)
if
am_ckpt
is
None
or
am_config
is
None
or
am_stat
is
None
or
phones_dict
is
None
:
if
use_pretrained_am
:
self
.
am_res_path
=
self
.
task_resource
.
res_dir
self
.
am_res_path
=
self
.
task_resource
.
res_dir
self
.
am_config
=
os
.
path
.
join
(
self
.
am_res_path
,
self
.
am_config
=
os
.
path
.
join
(
self
.
am_res_path
,
self
.
task_resource
.
res_dict
[
'config'
])
self
.
task_resource
.
res_dict
[
'config'
])
...
@@ -220,13 +227,19 @@ class TTSExecutor(BaseExecutor):
...
@@ -220,13 +227,19 @@ class TTSExecutor(BaseExecutor):
self
.
speaker_dict
=
speaker_dict
self
.
speaker_dict
=
speaker_dict
# voc
# voc
if
voc_ckpt
is
None
or
voc_config
is
None
or
voc_stat
is
None
:
use_pretrained_voc
=
True
else
:
use_pretrained_voc
=
False
voc_tag
=
voc
+
'-'
+
lang
voc_tag
=
voc
+
'-'
+
lang
self
.
task_resource
.
set_task_model
(
self
.
task_resource
.
set_task_model
(
model_tag
=
voc_tag
,
model_tag
=
voc_tag
,
model_type
=
1
,
# vocoder
model_type
=
1
,
# vocoder
skip_download
=
not
use_pretrained_voc
,
version
=
None
,
# default version
version
=
None
,
# default version
)
)
if
voc_ckpt
is
None
or
voc_config
is
None
or
voc_stat
is
None
:
if
use_pretrained_voc
:
self
.
voc_res_path
=
self
.
task_resource
.
voc_res_dir
self
.
voc_res_path
=
self
.
task_resource
.
voc_res_dir
self
.
voc_config
=
os
.
path
.
join
(
self
.
voc_config
=
os
.
path
.
join
(
self
.
voc_res_path
,
self
.
task_resource
.
voc_res_dict
[
'config'
])
self
.
voc_res_path
,
self
.
task_resource
.
voc_res_dict
[
'config'
])
...
...
paddlespeech/resource/resource.py
浏览文件 @
adc7c9b4
...
@@ -60,6 +60,7 @@ class CommonTaskResource:
...
@@ -60,6 +60,7 @@ class CommonTaskResource:
def
set_task_model
(
self
,
def
set_task_model
(
self
,
model_tag
:
str
,
model_tag
:
str
,
model_type
:
int
=
0
,
model_type
:
int
=
0
,
skip_download
:
bool
=
False
,
version
:
Optional
[
str
]
=
None
):
version
:
Optional
[
str
]
=
None
):
"""Set model tag and version of current task.
"""Set model tag and version of current task.
...
@@ -83,16 +84,18 @@ class CommonTaskResource:
...
@@ -83,16 +84,18 @@ class CommonTaskResource:
self
.
version
=
version
self
.
version
=
version
self
.
res_dict
=
self
.
pretrained_models
[
model_tag
][
version
]
self
.
res_dict
=
self
.
pretrained_models
[
model_tag
][
version
]
self
.
_format_path
(
self
.
res_dict
)
self
.
_format_path
(
self
.
res_dict
)
self
.
res_dir
=
self
.
_fetch
(
self
.
res_dict
,
if
not
skip_download
:
self
.
_get_model_dir
(
model_type
))
self
.
res_dir
=
self
.
_fetch
(
self
.
res_dict
,
self
.
_get_model_dir
(
model_type
))
else
:
else
:
assert
self
.
task
==
'tts'
,
'Vocoder will only be used in tts task.'
assert
self
.
task
==
'tts'
,
'Vocoder will only be used in tts task.'
self
.
voc_model_tag
=
model_tag
self
.
voc_model_tag
=
model_tag
self
.
voc_version
=
version
self
.
voc_version
=
version
self
.
voc_res_dict
=
self
.
pretrained_models
[
model_tag
][
version
]
self
.
voc_res_dict
=
self
.
pretrained_models
[
model_tag
][
version
]
self
.
_format_path
(
self
.
voc_res_dict
)
self
.
_format_path
(
self
.
voc_res_dict
)
self
.
voc_res_dir
=
self
.
_fetch
(
self
.
voc_res_dict
,
if
not
skip_download
:
self
.
_get_model_dir
(
model_type
))
self
.
voc_res_dir
=
self
.
_fetch
(
self
.
voc_res_dict
,
self
.
_get_model_dir
(
model_type
))
@
staticmethod
@
staticmethod
def
get_model_class
(
model_name
)
->
List
[
object
]:
def
get_model_class
(
model_name
)
->
List
[
object
]:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录