Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
fbe3c051
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fbe3c051
编写于
12月 30, 2021
作者:
小湉湉
提交者:
GitHub
12月 30, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add style_melgan and hifigan in tts cli, test=tts (#1241)
上级
a232cd8b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
128 addition
and
50 deletion
+128
-50
paddlespeech/cli/tts/infer.py
paddlespeech/cli/tts/infer.py
+63
-19
paddlespeech/t2s/exps/synthesize_e2e.py
paddlespeech/t2s/exps/synthesize_e2e.py
+24
-18
paddlespeech/t2s/frontend/phonectic.py
paddlespeech/t2s/frontend/phonectic.py
+38
-11
paddlespeech/t2s/frontend/zh_normalization/text_normlization.py
...speech/t2s/frontend/zh_normalization/text_normlization.py
+3
-2
未找到文件。
paddlespeech/cli/tts/infer.py
浏览文件 @
fbe3c051
...
@@ -178,6 +178,32 @@ pretrained_models = {
...
@@ -178,6 +178,32 @@ pretrained_models = {
'speech_stats'
:
'speech_stats'
:
'feats_stats.npy'
,
'feats_stats.npy'
,
},
},
# style_melgan
"style_melgan_csmsc-zh"
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip'
,
'md5'
:
'5de2d5348f396de0c966926b8c462755'
,
'config'
:
'default.yaml'
,
'ckpt'
:
'snapshot_iter_1500000.pdz'
,
'speech_stats'
:
'feats_stats.npy'
,
},
# hifigan
"hifigan_csmsc-zh"
:
{
'url'
:
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip'
,
'md5'
:
'dd40a3d88dfcf64513fba2f0f961ada6'
,
'config'
:
'default.yaml'
,
'ckpt'
:
'snapshot_iter_2500000.pdz'
,
'speech_stats'
:
'feats_stats.npy'
,
},
}
}
model_alias
=
{
model_alias
=
{
...
@@ -199,6 +225,14 @@ model_alias = {
...
@@ -199,6 +225,14 @@ model_alias = {
"paddlespeech.t2s.models.melgan:MelGANGenerator"
,
"paddlespeech.t2s.models.melgan:MelGANGenerator"
,
"mb_melgan_inference"
:
"mb_melgan_inference"
:
"paddlespeech.t2s.models.melgan:MelGANInference"
,
"paddlespeech.t2s.models.melgan:MelGANInference"
,
"style_melgan"
:
"paddlespeech.t2s.models.melgan:StyleMelGANGenerator"
,
"style_melgan_inference"
:
"paddlespeech.t2s.models.melgan:StyleMelGANInference"
,
"hifigan"
:
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator"
,
"hifigan_inference"
:
"paddlespeech.t2s.models.hifigan:HiFiGANInference"
,
}
}
...
@@ -266,7 +300,7 @@ class TTSExecutor(BaseExecutor):
...
@@ -266,7 +300,7 @@ class TTSExecutor(BaseExecutor):
default
=
'pwgan_csmsc'
,
default
=
'pwgan_csmsc'
,
choices
=
[
choices
=
[
'pwgan_csmsc'
,
'pwgan_ljspeech'
,
'pwgan_aishell3'
,
'pwgan_vctk'
,
'pwgan_csmsc'
,
'pwgan_ljspeech'
,
'pwgan_aishell3'
,
'pwgan_vctk'
,
'mb_melgan_csmsc'
'mb_melgan_csmsc'
,
'style_melgan_csmsc'
,
'hifigan_csmsc'
],
],
help
=
'Choose vocoder type of tts task.'
)
help
=
'Choose vocoder type of tts task.'
)
...
@@ -504,37 +538,47 @@ class TTSExecutor(BaseExecutor):
...
@@ -504,37 +538,47 @@ class TTSExecutor(BaseExecutor):
am_name
=
am
[:
am
.
rindex
(
'_'
)]
am_name
=
am
[:
am
.
rindex
(
'_'
)]
am_dataset
=
am
[
am
.
rindex
(
'_'
)
+
1
:]
am_dataset
=
am
[
am
.
rindex
(
'_'
)
+
1
:]
get_tone_ids
=
False
get_tone_ids
=
False
merge_sentences
=
False
if
am_name
==
'speedyspeech'
:
if
am_name
==
'speedyspeech'
:
get_tone_ids
=
True
get_tone_ids
=
True
if
lang
==
'zh'
:
if
lang
==
'zh'
:
input_ids
=
self
.
frontend
.
get_input_ids
(
input_ids
=
self
.
frontend
.
get_input_ids
(
text
,
merge_sentences
=
True
,
get_tone_ids
=
get_tone_ids
)
text
,
merge_sentences
=
merge_sentences
,
get_tone_ids
=
get_tone_ids
)
phone_ids
=
input_ids
[
"phone_ids"
]
phone_ids
=
input_ids
[
"phone_ids"
]
phone_ids
=
phone_ids
[
0
]
if
get_tone_ids
:
if
get_tone_ids
:
tone_ids
=
input_ids
[
"tone_ids"
]
tone_ids
=
input_ids
[
"tone_ids"
]
tone_ids
=
tone_ids
[
0
]
elif
lang
==
'en'
:
elif
lang
==
'en'
:
input_ids
=
self
.
frontend
.
get_input_ids
(
text
)
input_ids
=
self
.
frontend
.
get_input_ids
(
text
,
merge_sentences
=
merge_sentences
)
phone_ids
=
input_ids
[
"phone_ids"
]
phone_ids
=
input_ids
[
"phone_ids"
]
else
:
else
:
print
(
"lang should in {'zh', 'en'}!"
)
print
(
"lang should in {'zh', 'en'}!"
)
# am
flags
=
0
if
am_name
==
'speedyspeech'
:
for
i
in
range
(
len
(
phone_ids
)):
mel
=
self
.
am_inference
(
phone_ids
,
tone_ids
)
part_phone_ids
=
phone_ids
[
i
]
# fastspeech2
# am
else
:
if
am_name
==
'speedyspeech'
:
# multi speaker
part_tone_ids
=
tone_ids
[
i
]
if
am_dataset
in
{
"aishell3"
,
"vctk"
}:
mel
=
self
.
am_inference
(
part_phone_ids
,
part_tone_ids
)
mel
=
self
.
am_inference
(
# fastspeech2
phone_ids
,
spk_id
=
paddle
.
to_tensor
(
spk_id
))
else
:
else
:
mel
=
self
.
am_inference
(
phone_ids
)
# multi speaker
if
am_dataset
in
{
"aishell3"
,
"vctk"
}:
# voc
mel
=
self
.
am_inference
(
wav
=
self
.
voc_inference
(
mel
)
part_phone_ids
,
spk_id
=
paddle
.
to_tensor
(
spk_id
))
self
.
_outputs
[
'wav'
]
=
wav
else
:
mel
=
self
.
am_inference
(
part_phone_ids
)
# voc
wav
=
self
.
voc_inference
(
mel
)
if
flags
==
0
:
wav_all
=
wav
flags
=
1
else
:
wav_all
=
paddle
.
concat
([
wav_all
,
wav
])
self
.
_outputs
[
'wav'
]
=
wav_all
def
postprocess
(
self
,
output
:
str
=
'output.wav'
)
->
Union
[
str
,
os
.
PathLike
]:
def
postprocess
(
self
,
output
:
str
=
'output.wav'
)
->
Union
[
str
,
os
.
PathLike
]:
"""
"""
...
...
paddlespeech/t2s/exps/synthesize_e2e.py
浏览文件 @
fbe3c051
...
@@ -196,41 +196,47 @@ def evaluate(args):
...
@@ -196,41 +196,47 @@ def evaluate(args):
output_dir
=
Path
(
args
.
output_dir
)
output_dir
=
Path
(
args
.
output_dir
)
output_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
output_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
merge_sentences
=
False
for
utt_id
,
sentence
in
sentences
:
for
utt_id
,
sentence
in
sentences
:
get_tone_ids
=
False
get_tone_ids
=
False
if
am_name
==
'speedyspeech'
:
if
am_name
==
'speedyspeech'
:
get_tone_ids
=
True
get_tone_ids
=
True
if
args
.
lang
==
'zh'
:
if
args
.
lang
==
'zh'
:
input_ids
=
frontend
.
get_input_ids
(
input_ids
=
frontend
.
get_input_ids
(
sentence
,
merge_sentences
=
True
,
get_tone_ids
=
get_tone_ids
)
sentence
,
merge_sentences
=
merge_sentences
,
get_tone_ids
=
get_tone_ids
)
phone_ids
=
input_ids
[
"phone_ids"
]
phone_ids
=
input_ids
[
"phone_ids"
]
phone_ids
=
phone_ids
[
0
]
if
get_tone_ids
:
if
get_tone_ids
:
tone_ids
=
input_ids
[
"tone_ids"
]
tone_ids
=
input_ids
[
"tone_ids"
]
tone_ids
=
tone_ids
[
0
]
elif
args
.
lang
==
'en'
:
elif
args
.
lang
==
'en'
:
input_ids
=
frontend
.
get_input_ids
(
sentence
)
input_ids
=
frontend
.
get_input_ids
(
sentence
,
merge_sentences
=
merge_sentences
)
phone_ids
=
input_ids
[
"phone_ids"
]
phone_ids
=
input_ids
[
"phone_ids"
]
else
:
else
:
print
(
"lang should in {'zh', 'en'}!"
)
print
(
"lang should in {'zh', 'en'}!"
)
with
paddle
.
no_grad
():
with
paddle
.
no_grad
():
# acoustic model
flags
=
0
if
am_name
==
'fastspeech2'
:
for
i
in
range
(
len
(
phone_ids
)):
# multi speaker
part_phone_ids
=
phone_ids
[
i
]
if
am_dataset
in
{
"aishell3"
,
"vctk"
}:
# acoustic model
spk_id
=
paddle
.
to_tensor
(
args
.
spk_id
)
if
am_name
==
'fastspeech2'
:
mel
=
am_inference
(
phone_ids
,
spk_id
)
# multi speaker
if
am_dataset
in
{
"aishell3"
,
"vctk"
}:
spk_id
=
paddle
.
to_tensor
(
args
.
spk_id
)
mel
=
am_inference
(
part_phone_ids
,
spk_id
)
else
:
mel
=
am_inference
(
part_phone_ids
)
elif
am_name
==
'speedyspeech'
:
part_tone_ids
=
tone_ids
[
i
]
mel
=
am_inference
(
part_phone_ids
,
part_tone_ids
)
# vocoder
wav
=
voc_inference
(
mel
)
if
flags
==
0
:
wav_all
=
wav
flags
=
1
else
:
else
:
mel
=
am_inference
(
phone_ids
)
wav_all
=
paddle
.
concat
([
wav_all
,
wav
])
elif
am_name
==
'speedyspeech'
:
mel
=
am_inference
(
phone_ids
,
tone_ids
)
# vocoder
wav
=
voc_inference
(
mel
)
sf
.
write
(
sf
.
write
(
str
(
output_dir
/
(
utt_id
+
".wav"
)),
str
(
output_dir
/
(
utt_id
+
".wav"
)),
wav
.
numpy
(),
wav
_all
.
numpy
(),
samplerate
=
am_config
.
fs
)
samplerate
=
am_config
.
fs
)
print
(
f
"
{
utt_id
}
done!"
)
print
(
f
"
{
utt_id
}
done!"
)
...
...
paddlespeech/t2s/frontend/phonectic.py
浏览文件 @
fbe3c051
...
@@ -13,7 +13,9 @@
...
@@ -13,7 +13,9 @@
# limitations under the License.
# limitations under the License.
from
abc
import
ABC
from
abc
import
ABC
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
List
import
numpy
as
np
import
paddle
import
paddle
from
g2p_en
import
G2p
from
g2p_en
import
G2p
from
g2pM
import
G2pM
from
g2pM
import
G2pM
...
@@ -21,6 +23,7 @@ from g2pM import G2pM
...
@@ -21,6 +23,7 @@ from g2pM import G2pM
from
paddlespeech.t2s.frontend.normalizer.normalizer
import
normalize
from
paddlespeech.t2s.frontend.normalizer.normalizer
import
normalize
from
paddlespeech.t2s.frontend.punctuation
import
get_punctuations
from
paddlespeech.t2s.frontend.punctuation
import
get_punctuations
from
paddlespeech.t2s.frontend.vocab
import
Vocab
from
paddlespeech.t2s.frontend.vocab
import
Vocab
from
paddlespeech.t2s.frontend.zh_normalization.text_normlization
import
TextNormalizer
# discard opencc untill we find an easy solution to install it on windows
# discard opencc untill we find an easy solution to install it on windows
# from opencc import OpenCC
# from opencc import OpenCC
...
@@ -53,6 +56,7 @@ class English(Phonetics):
...
@@ -53,6 +56,7 @@ class English(Phonetics):
self
.
vocab
=
Vocab
(
self
.
phonemes
+
self
.
punctuations
)
self
.
vocab
=
Vocab
(
self
.
phonemes
+
self
.
punctuations
)
self
.
vocab_phones
=
{}
self
.
vocab_phones
=
{}
self
.
punc
=
":,;。?!“”‘’':,;.?!"
self
.
punc
=
":,;。?!“”‘’':,;.?!"
self
.
text_normalizer
=
TextNormalizer
()
if
phone_vocab_path
:
if
phone_vocab_path
:
with
open
(
phone_vocab_path
,
'rt'
)
as
f
:
with
open
(
phone_vocab_path
,
'rt'
)
as
f
:
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
phn_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
...
@@ -78,19 +82,42 @@ class English(Phonetics):
...
@@ -78,19 +82,42 @@ class English(Phonetics):
phonemes
=
[
item
for
item
in
phonemes
if
item
in
self
.
vocab
.
stoi
]
phonemes
=
[
item
for
item
in
phonemes
if
item
in
self
.
vocab
.
stoi
]
return
phonemes
return
phonemes
def
get_input_ids
(
self
,
sentence
:
str
)
->
paddle
.
Tensor
:
def
_p2id
(
self
,
phonemes
:
List
[
str
])
->
np
.
array
:
result
=
{}
# replace unk phone with sp
phones
=
self
.
phoneticize
(
sentence
)
phonemes
=
[
# remove start_symbol and end_symbol
phones
=
phones
[
1
:
-
1
]
phones
=
[
phn
for
phn
in
phones
if
not
phn
.
isspace
()]
phones
=
[
phn
if
(
phn
in
self
.
vocab_phones
and
phn
not
in
self
.
punc
)
else
"sp"
phn
if
(
phn
in
self
.
vocab_phones
and
phn
not
in
self
.
punc
)
else
"sp"
for
phn
in
phones
for
phn
in
phone
me
s
]
]
phone_ids
=
[
self
.
vocab_phones
[
phn
]
for
phn
in
phones
]
phone_ids
=
[
self
.
vocab_phones
[
item
]
for
item
in
phonemes
]
phone_ids
=
paddle
.
to_tensor
(
phone_ids
)
return
np
.
array
(
phone_ids
,
np
.
int64
)
result
[
"phone_ids"
]
=
phone_ids
def
get_input_ids
(
self
,
sentence
:
str
,
merge_sentences
:
bool
=
False
)
->
paddle
.
Tensor
:
result
=
{}
sentences
=
self
.
text_normalizer
.
_split
(
sentence
,
lang
=
"en"
)
phones_list
=
[]
temp_phone_ids
=
[]
for
sentence
in
sentences
:
phones
=
self
.
phoneticize
(
sentence
)
# remove start_symbol and end_symbol
phones
=
phones
[
1
:
-
1
]
phones
=
[
phn
for
phn
in
phones
if
not
phn
.
isspace
()]
phones_list
.
append
(
phones
)
if
merge_sentences
:
merge_list
=
sum
(
phones_list
,
[])
# rm the last 'sp' to avoid the noise at the end
# cause in the training data, no 'sp' in the end
if
merge_list
[
-
1
]
==
'sp'
:
merge_list
=
merge_list
[:
-
1
]
phones_list
=
[]
phones_list
.
append
(
merge_list
)
for
part_phones_list
in
phones_list
:
phone_ids
=
self
.
_p2id
(
part_phones_list
)
phone_ids
=
paddle
.
to_tensor
(
phone_ids
)
temp_phone_ids
.
append
(
phone_ids
)
result
[
"phone_ids"
]
=
temp_phone_ids
return
result
return
result
def
numericalize
(
self
,
phonemes
):
def
numericalize
(
self
,
phonemes
):
...
...
paddlespeech/t2s/frontend/zh_normalization/text_normlization.py
浏览文件 @
fbe3c051
...
@@ -53,7 +53,7 @@ class TextNormalizer():
...
@@ -53,7 +53,7 @@ class TextNormalizer():
def
__init__
(
self
):
def
__init__
(
self
):
self
.
SENTENCE_SPLITOR
=
re
.
compile
(
r
'([:,;。?!,;?!][”’]?)'
)
self
.
SENTENCE_SPLITOR
=
re
.
compile
(
r
'([:,;。?!,;?!][”’]?)'
)
def
_split
(
self
,
text
:
str
)
->
List
[
str
]:
def
_split
(
self
,
text
:
str
,
lang
=
"zh"
)
->
List
[
str
]:
"""Split long text into sentences with sentence-splitting punctuations.
"""Split long text into sentences with sentence-splitting punctuations.
Parameters
Parameters
----------
----------
...
@@ -65,7 +65,8 @@ class TextNormalizer():
...
@@ -65,7 +65,8 @@ class TextNormalizer():
Sentences.
Sentences.
"""
"""
# Only for pure Chinese here
# Only for pure Chinese here
text
=
text
.
replace
(
" "
,
""
)
if
lang
==
"zh"
:
text
=
text
.
replace
(
" "
,
""
)
text
=
self
.
SENTENCE_SPLITOR
.
sub
(
r
'\1\n'
,
text
)
text
=
self
.
SENTENCE_SPLITOR
.
sub
(
r
'\1\n'
,
text
)
text
=
text
.
strip
()
text
=
text
.
strip
()
sentences
=
[
sentence
.
strip
()
for
sentence
in
re
.
split
(
r
'\n+'
,
text
)]
sentences
=
[
sentence
.
strip
()
for
sentence
in
re
.
split
(
r
'\n+'
,
text
)]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录