Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
98ce69d0
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
98ce69d0
编写于
1月 05, 2022
作者:
小湉湉
提交者:
GitHub
1月 05, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1259 from jerryuhoo/develop
[TTS]Add multi-speaker support for the SpeedySpeech model
上级
4cab9f62
1323242e
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
125 addition
and
28 deletion
+125
-28
examples/csmsc/tts2/local/preprocess.sh
examples/csmsc/tts2/local/preprocess.sh
+3
-0
paddlespeech/t2s/datasets/am_batch_fn.py
paddlespeech/t2s/datasets/am_batch_fn.py
+41
-1
paddlespeech/t2s/exps/speedyspeech/normalize.py
paddlespeech/t2s/exps/speedyspeech/normalize.py
+10
-1
paddlespeech/t2s/exps/speedyspeech/preprocess.py
paddlespeech/t2s/exps/speedyspeech/preprocess.py
+5
-1
paddlespeech/t2s/exps/speedyspeech/train.py
paddlespeech/t2s/exps/speedyspeech/train.py
+28
-10
paddlespeech/t2s/models/speedyspeech/speedyspeech.py
paddlespeech/t2s/models/speedyspeech/speedyspeech.py
+27
-13
paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py
paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py
+11
-2
未找到文件。
examples/csmsc/tts2/local/preprocess.sh
浏览文件 @
98ce69d0
...
@@ -45,6 +45,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
...
@@ -45,6 +45,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--stats
=
dump/train/feats_stats.npy
\
--stats
=
dump/train/feats_stats.npy
\
--phones-dict
=
dump/phone_id_map.txt
\
--phones-dict
=
dump/phone_id_map.txt
\
--tones-dict
=
dump/tone_id_map.txt
\
--tones-dict
=
dump/tone_id_map.txt
\
--speaker-dict
=
dump/speaker_id_map.txt
\
--use-relative-path
=
True
--use-relative-path
=
True
python3
${
BIN_DIR
}
/normalize.py
\
python3
${
BIN_DIR
}
/normalize.py
\
...
@@ -53,6 +54,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
...
@@ -53,6 +54,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--stats
=
dump/train/feats_stats.npy
\
--stats
=
dump/train/feats_stats.npy
\
--phones-dict
=
dump/phone_id_map.txt
\
--phones-dict
=
dump/phone_id_map.txt
\
--tones-dict
=
dump/tone_id_map.txt
\
--tones-dict
=
dump/tone_id_map.txt
\
--speaker-dict
=
dump/speaker_id_map.txt
\
--use-relative-path
=
True
--use-relative-path
=
True
python3
${
BIN_DIR
}
/normalize.py
\
python3
${
BIN_DIR
}
/normalize.py
\
...
@@ -61,6 +63,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
...
@@ -61,6 +63,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--stats
=
dump/train/feats_stats.npy
\
--stats
=
dump/train/feats_stats.npy
\
--phones-dict
=
dump/phone_id_map.txt
\
--phones-dict
=
dump/phone_id_map.txt
\
--tones-dict
=
dump/tone_id_map.txt
\
--tones-dict
=
dump/tone_id_map.txt
\
--speaker-dict
=
dump/speaker_id_map.txt
\
--use-relative-path
=
True
--use-relative-path
=
True
fi
fi
paddlespeech/t2s/datasets/am_batch_fn.py
浏览文件 @
98ce69d0
...
@@ -17,7 +17,7 @@ import paddle
...
@@ -17,7 +17,7 @@ import paddle
from
paddlespeech.t2s.data.batch
import
batch_sequences
from
paddlespeech.t2s.data.batch
import
batch_sequences
def
speedyspeech_batch_fn
(
examples
):
def
speedyspeech_
single_spk_
batch_fn
(
examples
):
# fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
# fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
phones
=
[
np
.
array
(
item
[
"phones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
phones
=
[
np
.
array
(
item
[
"phones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
tones
=
[
np
.
array
(
item
[
"tones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
tones
=
[
np
.
array
(
item
[
"tones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
...
@@ -54,6 +54,46 @@ def speedyspeech_batch_fn(examples):
...
@@ -54,6 +54,46 @@ def speedyspeech_batch_fn(examples):
}
}
return
batch
return
batch
def
speedyspeech_multi_spk_batch_fn
(
examples
):
# fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
phones
=
[
np
.
array
(
item
[
"phones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
tones
=
[
np
.
array
(
item
[
"tones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
feats
=
[
np
.
array
(
item
[
"feats"
],
dtype
=
np
.
float32
)
for
item
in
examples
]
durations
=
[
np
.
array
(
item
[
"durations"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
num_phones
=
[
np
.
array
(
item
[
"num_phones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
num_frames
=
[
np
.
array
(
item
[
"num_frames"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
phones
=
batch_sequences
(
phones
)
tones
=
batch_sequences
(
tones
)
feats
=
batch_sequences
(
feats
)
durations
=
batch_sequences
(
durations
)
# convert each batch to paddle.Tensor
phones
=
paddle
.
to_tensor
(
phones
)
tones
=
paddle
.
to_tensor
(
tones
)
feats
=
paddle
.
to_tensor
(
feats
)
durations
=
paddle
.
to_tensor
(
durations
)
num_phones
=
paddle
.
to_tensor
(
num_phones
)
num_frames
=
paddle
.
to_tensor
(
num_frames
)
batch
=
{
"phones"
:
phones
,
"tones"
:
tones
,
"num_phones"
:
num_phones
,
"num_frames"
:
num_frames
,
"feats"
:
feats
,
"durations"
:
durations
,
}
if
"spk_id"
in
examples
[
0
]:
spk_id
=
[
np
.
array
(
item
[
"spk_id"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
spk_id
=
paddle
.
to_tensor
(
spk_id
)
batch
[
"spk_id"
]
=
spk_id
return
batch
def
fastspeech2_single_spk_batch_fn
(
examples
):
def
fastspeech2_single_spk_batch_fn
(
examples
):
# fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
# fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
...
...
paddlespeech/t2s/exps/speedyspeech/normalize.py
浏览文件 @
98ce69d0
...
@@ -47,7 +47,8 @@ def main():
...
@@ -47,7 +47,8 @@ def main():
"--phones-dict"
,
type
=
str
,
default
=
None
,
help
=
"phone vocabulary file."
)
"--phones-dict"
,
type
=
str
,
default
=
None
,
help
=
"phone vocabulary file."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--tones-dict"
,
type
=
str
,
default
=
None
,
help
=
"tone vocabulary file."
)
"--tones-dict"
,
type
=
str
,
default
=
None
,
help
=
"tone vocabulary file."
)
parser
.
add_argument
(
"--speaker-dict"
,
type
=
str
,
default
=
None
,
help
=
"speaker id map file."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--verbose"
,
"--verbose"
,
type
=
int
,
type
=
int
,
...
@@ -121,6 +122,12 @@ def main():
...
@@ -121,6 +122,12 @@ def main():
for
tone
,
id
in
tone_id
:
for
tone
,
id
in
tone_id
:
vocab_tones
[
tone
]
=
int
(
id
)
vocab_tones
[
tone
]
=
int
(
id
)
vocab_speaker
=
{}
with
open
(
args
.
speaker_dict
,
'rt'
)
as
f
:
spk_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
for
spk
,
id
in
spk_id
:
vocab_speaker
[
spk
]
=
int
(
id
)
# process each file
# process each file
output_metadata
=
[]
output_metadata
=
[]
...
@@ -135,11 +142,13 @@ def main():
...
@@ -135,11 +142,13 @@ def main():
np
.
save
(
mel_path
,
mel
.
astype
(
np
.
float32
),
allow_pickle
=
False
)
np
.
save
(
mel_path
,
mel
.
astype
(
np
.
float32
),
allow_pickle
=
False
)
phone_ids
=
[
vocab_phones
[
p
]
for
p
in
item
[
'phones'
]]
phone_ids
=
[
vocab_phones
[
p
]
for
p
in
item
[
'phones'
]]
tone_ids
=
[
vocab_tones
[
p
]
for
p
in
item
[
'tones'
]]
tone_ids
=
[
vocab_tones
[
p
]
for
p
in
item
[
'tones'
]]
spk_id
=
vocab_speaker
[
item
[
"speaker"
]]
if
args
.
use_relative_path
:
if
args
.
use_relative_path
:
# convert absolute path to relative path:
# convert absolute path to relative path:
mel_path
=
mel_path
.
relative_to
(
dumpdir
)
mel_path
=
mel_path
.
relative_to
(
dumpdir
)
output_metadata
.
append
({
output_metadata
.
append
({
'utt_id'
:
utt_id
,
'utt_id'
:
utt_id
,
"spk_id"
:
spk_id
,
'phones'
:
phone_ids
,
'phones'
:
phone_ids
,
'tones'
:
tone_ids
,
'tones'
:
tone_ids
,
'num_phones'
:
item
[
'num_phones'
],
'num_phones'
:
item
[
'num_phones'
],
...
...
paddlespeech/t2s/exps/speedyspeech/preprocess.py
浏览文件 @
98ce69d0
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
import
re
import
re
import
os
from
concurrent.futures
import
ThreadPoolExecutor
from
concurrent.futures
import
ThreadPoolExecutor
from
operator
import
itemgetter
from
operator
import
itemgetter
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -32,7 +33,7 @@ from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_
...
@@ -32,7 +33,7 @@ from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_
from
paddlespeech.t2s.datasets.preprocess_utils
import
get_phn_dur
from
paddlespeech.t2s.datasets.preprocess_utils
import
get_phn_dur
from
paddlespeech.t2s.datasets.preprocess_utils
import
get_phones_tones
from
paddlespeech.t2s.datasets.preprocess_utils
import
get_phones_tones
from
paddlespeech.t2s.datasets.preprocess_utils
import
merge_silence
from
paddlespeech.t2s.datasets.preprocess_utils
import
merge_silence
from
paddlespeech.t2s.datasets.preprocess_utils
import
get_spk_id_map
def
process_sentence
(
config
:
Dict
[
str
,
Any
],
def
process_sentence
(
config
:
Dict
[
str
,
Any
],
fp
:
Path
,
fp
:
Path
,
...
@@ -101,6 +102,7 @@ def process_sentence(config: Dict[str, Any],
...
@@ -101,6 +102,7 @@ def process_sentence(config: Dict[str, Any],
"utt_id"
:
utt_id
,
"utt_id"
:
utt_id
,
"phones"
:
phones
,
"phones"
:
phones
,
"tones"
:
tones
,
"tones"
:
tones
,
"speaker"
:
speaker
,
"num_phones"
:
len
(
phones
),
"num_phones"
:
len
(
phones
),
"num_frames"
:
num_frames
,
"num_frames"
:
num_frames
,
"durations"
:
durations
,
"durations"
:
durations
,
...
@@ -229,6 +231,8 @@ def main():
...
@@ -229,6 +231,8 @@ def main():
tone_id_map_path
=
dumpdir
/
"tone_id_map.txt"
tone_id_map_path
=
dumpdir
/
"tone_id_map.txt"
get_phones_tones
(
sentences
,
phone_id_map_path
,
tone_id_map_path
,
get_phones_tones
(
sentences
,
phone_id_map_path
,
tone_id_map_path
,
args
.
dataset
)
args
.
dataset
)
speaker_id_map_path
=
dumpdir
/
"speaker_id_map.txt"
get_spk_id_map
(
speaker_set
,
speaker_id_map_path
)
if
args
.
dataset
==
"baker"
:
if
args
.
dataset
==
"baker"
:
wav_files
=
sorted
(
list
((
rootdir
/
"Wave"
).
rglob
(
"*.wav"
)))
wav_files
=
sorted
(
list
((
rootdir
/
"Wave"
).
rglob
(
"*.wav"
)))
...
...
paddlespeech/t2s/exps/speedyspeech/train.py
浏览文件 @
98ce69d0
...
@@ -27,7 +27,8 @@ from paddle.io import DataLoader
...
@@ -27,7 +27,8 @@ from paddle.io import DataLoader
from
paddle.io
import
DistributedBatchSampler
from
paddle.io
import
DistributedBatchSampler
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
paddlespeech.t2s.datasets.am_batch_fn
import
speedyspeech_batch_fn
from
paddlespeech.t2s.datasets.am_batch_fn
import
speedyspeech_single_spk_batch_fn
from
paddlespeech.t2s.datasets.am_batch_fn
import
speedyspeech_multi_spk_batch_fn
from
paddlespeech.t2s.datasets.data_table
import
DataTable
from
paddlespeech.t2s.datasets.data_table
import
DataTable
from
paddlespeech.t2s.models.speedyspeech
import
SpeedySpeech
from
paddlespeech.t2s.models.speedyspeech
import
SpeedySpeech
from
paddlespeech.t2s.models.speedyspeech
import
SpeedySpeechEvaluator
from
paddlespeech.t2s.models.speedyspeech
import
SpeedySpeechEvaluator
...
@@ -57,6 +58,21 @@ def train_sp(args, config):
...
@@ -57,6 +58,21 @@ def train_sp(args, config):
f
"rank:
{
dist
.
get_rank
()
}
, pid:
{
os
.
getpid
()
}
, parent_pid:
{
os
.
getppid
()
}
"
,
f
"rank:
{
dist
.
get_rank
()
}
, pid:
{
os
.
getpid
()
}
, parent_pid:
{
os
.
getppid
()
}
"
,
)
)
fields
=
[
"phones"
,
"tones"
,
"num_phones"
,
"num_frames"
,
"feats"
,
"durations"
]
spk_num
=
None
if
args
.
speaker_dict
is
not
None
:
print
(
"multiple speaker speedyspeech!"
)
collate_fn
=
speedyspeech_multi_spk_batch_fn
with
open
(
args
.
speaker_dict
,
'rt'
)
as
f
:
spk_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
spk_num
=
len
(
spk_id
)
fields
+=
[
"spk_id"
]
else
:
print
(
"single speaker speedyspeech!"
)
collate_fn
=
speedyspeech_single_spk_batch_fn
print
(
"spk_num:"
,
spk_num
)
# dataloader has been too verbose
# dataloader has been too verbose
logging
.
getLogger
(
"DataLoader"
).
disabled
=
True
logging
.
getLogger
(
"DataLoader"
).
disabled
=
True
...
@@ -71,9 +87,7 @@ def train_sp(args, config):
...
@@ -71,9 +87,7 @@ def train_sp(args, config):
train_dataset
=
DataTable
(
train_dataset
=
DataTable
(
data
=
train_metadata
,
data
=
train_metadata
,
fields
=
[
fields
=
fields
,
"phones"
,
"tones"
,
"num_phones"
,
"num_frames"
,
"feats"
,
"durations"
],
converters
=
{
converters
=
{
"feats"
:
np
.
load
,
"feats"
:
np
.
load
,
},
)
},
)
...
@@ -87,9 +101,7 @@ def train_sp(args, config):
...
@@ -87,9 +101,7 @@ def train_sp(args, config):
dev_dataset
=
DataTable
(
dev_dataset
=
DataTable
(
data
=
dev_metadata
,
data
=
dev_metadata
,
fields
=
[
fields
=
fields
,
"phones"
,
"tones"
,
"num_phones"
,
"num_frames"
,
"feats"
,
"durations"
],
converters
=
{
converters
=
{
"feats"
:
np
.
load
,
"feats"
:
np
.
load
,
},
)
},
)
...
@@ -105,14 +117,14 @@ def train_sp(args, config):
...
@@ -105,14 +117,14 @@ def train_sp(args, config):
train_dataloader
=
DataLoader
(
train_dataloader
=
DataLoader
(
train_dataset
,
train_dataset
,
batch_sampler
=
train_sampler
,
batch_sampler
=
train_sampler
,
collate_fn
=
speedyspeech_batch
_fn
,
collate_fn
=
collate
_fn
,
num_workers
=
config
.
num_workers
)
num_workers
=
config
.
num_workers
)
dev_dataloader
=
DataLoader
(
dev_dataloader
=
DataLoader
(
dev_dataset
,
dev_dataset
,
shuffle
=
False
,
shuffle
=
False
,
drop_last
=
False
,
drop_last
=
False
,
batch_size
=
config
.
batch_size
,
batch_size
=
config
.
batch_size
,
collate_fn
=
speedyspeech_batch
_fn
,
collate_fn
=
collate
_fn
,
num_workers
=
config
.
num_workers
)
num_workers
=
config
.
num_workers
)
print
(
"dataloaders done!"
)
print
(
"dataloaders done!"
)
with
open
(
args
.
phones_dict
,
"r"
)
as
f
:
with
open
(
args
.
phones_dict
,
"r"
)
as
f
:
...
@@ -125,7 +137,7 @@ def train_sp(args, config):
...
@@ -125,7 +137,7 @@ def train_sp(args, config):
print
(
"tone_size:"
,
tone_size
)
print
(
"tone_size:"
,
tone_size
)
model
=
SpeedySpeech
(
model
=
SpeedySpeech
(
vocab_size
=
vocab_size
,
tone_size
=
tone_size
,
**
config
[
"model"
])
vocab_size
=
vocab_size
,
tone_size
=
tone_size
,
spk_num
=
spk_num
,
**
config
[
"model"
])
if
world_size
>
1
:
if
world_size
>
1
:
model
=
DataParallel
(
model
)
model
=
DataParallel
(
model
)
print
(
"model done!"
)
print
(
"model done!"
)
...
@@ -184,6 +196,12 @@ def main():
...
@@ -184,6 +196,12 @@ def main():
parser
.
add_argument
(
parser
.
add_argument
(
"--tones-dict"
,
type
=
str
,
default
=
None
,
help
=
"tone vocabulary file."
)
"--tones-dict"
,
type
=
str
,
default
=
None
,
help
=
"tone vocabulary file."
)
parser
.
add_argument
(
"--speaker-dict"
,
type
=
str
,
default
=
None
,
help
=
"speaker id map file for multiple speaker model."
)
# 这里可以多传入 max_epoch 等
# 这里可以多传入 max_epoch 等
args
,
rest
=
parser
.
parse_known_args
()
args
,
rest
=
parser
.
parse_known_args
()
...
...
paddlespeech/t2s/models/speedyspeech/speedyspeech.py
浏览文件 @
98ce69d0
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
import
paddle.nn.functional
as
F
from
paddlespeech.t2s.modules.positional_encoding
import
sinusoid_position_encoding
from
paddlespeech.t2s.modules.positional_encoding
import
sinusoid_position_encoding
...
@@ -96,7 +96,7 @@ class TextEmbedding(nn.Layer):
...
@@ -96,7 +96,7 @@ class TextEmbedding(nn.Layer):
class
SpeedySpeechEncoder
(
nn
.
Layer
):
class
SpeedySpeechEncoder
(
nn
.
Layer
):
def
__init__
(
self
,
vocab_size
,
tone_size
,
hidden_size
,
kernel_size
,
def
__init__
(
self
,
vocab_size
,
tone_size
,
hidden_size
,
kernel_size
,
dilations
):
dilations
,
spk_num
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
embedding
=
TextEmbedding
(
self
.
embedding
=
TextEmbedding
(
vocab_size
,
vocab_size
,
...
@@ -104,6 +104,15 @@ class SpeedySpeechEncoder(nn.Layer):
...
@@ -104,6 +104,15 @@ class SpeedySpeechEncoder(nn.Layer):
tone_size
,
tone_size
,
padding_idx
=
0
,
padding_idx
=
0
,
tone_padding_idx
=
0
)
tone_padding_idx
=
0
)
if
spk_num
:
self
.
spk_emb
=
nn
.
Embedding
(
num_embeddings
=
spk_num
,
embedding_dim
=
hidden_size
,
padding_idx
=
0
)
else
:
self
.
spk_emb
=
None
self
.
prenet
=
nn
.
Sequential
(
self
.
prenet
=
nn
.
Sequential
(
nn
.
Linear
(
hidden_size
,
hidden_size
),
nn
.
Linear
(
hidden_size
,
hidden_size
),
nn
.
ReLU
(),
)
nn
.
ReLU
(),
)
...
@@ -118,8 +127,10 @@ class SpeedySpeechEncoder(nn.Layer):
...
@@ -118,8 +127,10 @@ class SpeedySpeechEncoder(nn.Layer):
nn
.
BatchNorm1D
(
hidden_size
,
data_format
=
"NLC"
),
nn
.
BatchNorm1D
(
hidden_size
,
data_format
=
"NLC"
),
nn
.
Linear
(
hidden_size
,
hidden_size
),
)
nn
.
Linear
(
hidden_size
,
hidden_size
),
)
def
forward
(
self
,
text
,
tones
):
def
forward
(
self
,
text
,
tones
,
spk_id
=
None
):
embedding
=
self
.
embedding
(
text
,
tones
)
embedding
=
self
.
embedding
(
text
,
tones
)
if
self
.
spk_emb
:
embedding
+=
self
.
spk_emb
(
spk_id
).
unsqueeze
(
1
)
embedding
=
self
.
prenet
(
embedding
)
embedding
=
self
.
prenet
(
embedding
)
x
=
self
.
res_blocks
(
embedding
)
x
=
self
.
res_blocks
(
embedding
)
x
=
embedding
+
self
.
postnet1
(
x
)
x
=
embedding
+
self
.
postnet1
(
x
)
...
@@ -171,11 +182,12 @@ class SpeedySpeech(nn.Layer):
...
@@ -171,11 +182,12 @@ class SpeedySpeech(nn.Layer):
decoder_output_size
,
decoder_output_size
,
decoder_kernel_size
,
decoder_kernel_size
,
decoder_dilations
,
decoder_dilations
,
tone_size
=
None
,
):
tone_size
=
None
,
spk_num
=
None
):
super
().
__init__
()
super
().
__init__
()
encoder
=
SpeedySpeechEncoder
(
vocab_size
,
tone_size
,
encoder
=
SpeedySpeechEncoder
(
vocab_size
,
tone_size
,
encoder_hidden_size
,
encoder_kernel_size
,
encoder_hidden_size
,
encoder_kernel_size
,
encoder_dilations
)
encoder_dilations
,
spk_num
)
duration_predictor
=
DurationPredictor
(
duration_predictor_hidden_size
)
duration_predictor
=
DurationPredictor
(
duration_predictor_hidden_size
)
decoder
=
SpeedySpeechDecoder
(
decoder_hidden_size
,
decoder_output_size
,
decoder
=
SpeedySpeechDecoder
(
decoder_hidden_size
,
decoder_output_size
,
decoder_kernel_size
,
decoder_dilations
)
decoder_kernel_size
,
decoder_dilations
)
...
@@ -184,13 +196,15 @@ class SpeedySpeech(nn.Layer):
...
@@ -184,13 +196,15 @@ class SpeedySpeech(nn.Layer):
self
.
duration_predictor
=
duration_predictor
self
.
duration_predictor
=
duration_predictor
self
.
decoder
=
decoder
self
.
decoder
=
decoder
def
forward
(
self
,
text
,
tones
,
durations
):
def
forward
(
self
,
text
,
tones
,
durations
,
spk_id
:
paddle
.
Tensor
=
None
):
# input of embedding must be int64
# input of embedding must be int64
text
=
paddle
.
cast
(
text
,
'int64'
)
text
=
paddle
.
cast
(
text
,
'int64'
)
tones
=
paddle
.
cast
(
tones
,
'int64'
)
tones
=
paddle
.
cast
(
tones
,
'int64'
)
if
spk_id
is
not
None
:
spk_id
=
paddle
.
cast
(
spk_id
,
'int64'
)
durations
=
paddle
.
cast
(
durations
,
'int64'
)
durations
=
paddle
.
cast
(
durations
,
'int64'
)
encodings
=
self
.
encoder
(
text
,
tones
)
encodings
=
self
.
encoder
(
text
,
tones
,
spk_id
)
# (B, T)
pred_durations
=
self
.
duration_predictor
(
encodings
.
detach
())
pred_durations
=
self
.
duration_predictor
(
encodings
.
detach
())
# expand encodings
# expand encodings
...
@@ -204,7 +218,7 @@ class SpeedySpeech(nn.Layer):
...
@@ -204,7 +218,7 @@ class SpeedySpeech(nn.Layer):
decoded
=
self
.
decoder
(
encodings
)
decoded
=
self
.
decoder
(
encodings
)
return
decoded
,
pred_durations
return
decoded
,
pred_durations
def
inference
(
self
,
text
,
tones
=
None
):
def
inference
(
self
,
text
,
tones
=
None
,
spk_id
=
None
):
# text: [T]
# text: [T]
# tones: [T]
# tones: [T]
# input of embedding must be int64
# input of embedding must be int64
...
@@ -214,7 +228,8 @@ class SpeedySpeech(nn.Layer):
...
@@ -214,7 +228,8 @@ class SpeedySpeech(nn.Layer):
tones
=
paddle
.
cast
(
tones
,
'int64'
)
tones
=
paddle
.
cast
(
tones
,
'int64'
)
tones
=
tones
.
unsqueeze
(
0
)
tones
=
tones
.
unsqueeze
(
0
)
encodings
=
self
.
encoder
(
text
,
tones
)
encodings
=
self
.
encoder
(
text
,
tones
,
spk_id
)
pred_durations
=
self
.
duration_predictor
(
encodings
)
# (1, T)
pred_durations
=
self
.
duration_predictor
(
encodings
)
# (1, T)
durations_to_expand
=
paddle
.
round
(
pred_durations
.
exp
())
durations_to_expand
=
paddle
.
round
(
pred_durations
.
exp
())
durations_to_expand
=
(
durations_to_expand
).
astype
(
paddle
.
int64
)
durations_to_expand
=
(
durations_to_expand
).
astype
(
paddle
.
int64
)
...
@@ -240,14 +255,13 @@ class SpeedySpeech(nn.Layer):
...
@@ -240,14 +255,13 @@ class SpeedySpeech(nn.Layer):
decoded
=
self
.
decoder
(
encodings
)
decoded
=
self
.
decoder
(
encodings
)
return
decoded
[
0
]
return
decoded
[
0
]
class
SpeedySpeechInference
(
nn
.
Layer
):
class
SpeedySpeechInference
(
nn
.
Layer
):
def
__init__
(
self
,
normalizer
,
speedyspeech_model
):
def
__init__
(
self
,
normalizer
,
speedyspeech_model
):
super
().
__init__
()
super
().
__init__
()
self
.
normalizer
=
normalizer
self
.
normalizer
=
normalizer
self
.
acoustic_model
=
speedyspeech_model
self
.
acoustic_model
=
speedyspeech_model
def
forward
(
self
,
phones
,
tones
):
def
forward
(
self
,
phones
,
tones
,
spk_id
=
None
):
normalized_mel
=
self
.
acoustic_model
.
inference
(
phones
,
tones
)
normalized_mel
=
self
.
acoustic_model
.
inference
(
phones
,
tones
,
spk_id
)
logmel
=
self
.
normalizer
.
inverse
(
normalized_mel
)
logmel
=
self
.
normalizer
.
inverse
(
normalized_mel
)
return
logmel
return
logmel
paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py
浏览文件 @
98ce69d0
...
@@ -50,10 +50,15 @@ class SpeedySpeechUpdater(StandardUpdater):
...
@@ -50,10 +50,15 @@ class SpeedySpeechUpdater(StandardUpdater):
self
.
msg
=
"Rank: {}, "
.
format
(
dist
.
get_rank
())
self
.
msg
=
"Rank: {}, "
.
format
(
dist
.
get_rank
())
losses_dict
=
{}
losses_dict
=
{}
# spk_id!=None in multiple spk speedyspeech
spk_id
=
batch
[
"spk_id"
]
if
"spk_id"
in
batch
else
None
decoded
,
predicted_durations
=
self
.
model
(
decoded
,
predicted_durations
=
self
.
model
(
text
=
batch
[
"phones"
],
text
=
batch
[
"phones"
],
tones
=
batch
[
"tones"
],
tones
=
batch
[
"tones"
],
durations
=
batch
[
"durations"
])
durations
=
batch
[
"durations"
],
spk_id
=
spk_id
)
target_mel
=
batch
[
"feats"
]
target_mel
=
batch
[
"feats"
]
spec_mask
=
F
.
sequence_mask
(
spec_mask
=
F
.
sequence_mask
(
...
@@ -112,10 +117,14 @@ class SpeedySpeechEvaluator(StandardEvaluator):
...
@@ -112,10 +117,14 @@ class SpeedySpeechEvaluator(StandardEvaluator):
self
.
msg
=
"Evaluate: "
self
.
msg
=
"Evaluate: "
losses_dict
=
{}
losses_dict
=
{}
spk_id
=
batch
[
"spk_id"
]
if
"spk_id"
in
batch
else
None
decoded
,
predicted_durations
=
self
.
model
(
decoded
,
predicted_durations
=
self
.
model
(
text
=
batch
[
"phones"
],
text
=
batch
[
"phones"
],
tones
=
batch
[
"tones"
],
tones
=
batch
[
"tones"
],
durations
=
batch
[
"durations"
])
durations
=
batch
[
"durations"
],
spk_id
=
spk_id
)
target_mel
=
batch
[
"feats"
]
target_mel
=
batch
[
"feats"
]
spec_mask
=
F
.
sequence_mask
(
spec_mask
=
F
.
sequence_mask
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录