Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
11991b6d
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
11991b6d
编写于
12月 31, 2021
作者:
J
Jerryuhoo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add multi-speaker support for speedyspeech
上级
f27d9d50
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
242 addition
and
22 deletion
+242
-22
examples/aishell3/tts2/default_multi.yaml
examples/aishell3/tts2/default_multi.yaml
+52
-0
paddlespeech/t2s/datasets/am_batch_fn.py
paddlespeech/t2s/datasets/am_batch_fn.py
+41
-1
paddlespeech/t2s/exps/speedyspeech/normalize.py
paddlespeech/t2s/exps/speedyspeech/normalize.py
+10
-1
paddlespeech/t2s/exps/speedyspeech/preprocess.py
paddlespeech/t2s/exps/speedyspeech/preprocess.py
+27
-1
paddlespeech/t2s/exps/speedyspeech/train.py
paddlespeech/t2s/exps/speedyspeech/train.py
+28
-10
paddlespeech/t2s/models/speedyspeech/speedyspeech.py
paddlespeech/t2s/models/speedyspeech/speedyspeech.py
+73
-7
paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py
paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py
+11
-2
未找到文件。
examples/aishell3/tts2/default_multi.yaml
0 → 100644
浏览文件 @
11991b6d
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs
:
24000
# Sampling rate.
n_fft
:
2048
# FFT size (samples).
n_shift
:
300
# Hop size (samples). 12.5ms
win_length
:
1200
# Window length (samples). 50ms
# If set to null, it will be the same as fft_size.
window
:
"
hann"
# Window function.
n_mels
:
80
# Number of mel basis.
fmin
:
80
# Minimum freq in mel basis calculation.
fmax
:
7600
# Maximum frequency in mel basis calculation.
###########################################################
# DATA SETTING #
###########################################################
batch_size
:
32
num_workers
:
4
###########################################################
# MODEL SETTING #
###########################################################
model
:
encoder_hidden_size
:
128
encoder_kernel_size
:
3
encoder_dilations
:
[
1
,
3
,
9
,
27
,
1
,
3
,
9
,
27
,
1
,
1
]
duration_predictor_hidden_size
:
128
decoder_hidden_size
:
128
decoder_output_size
:
80
decoder_kernel_size
:
3
decoder_dilations
:
[
1
,
3
,
9
,
27
,
1
,
3
,
9
,
27
,
1
,
3
,
9
,
27
,
1
,
3
,
9
,
27
,
1
,
1
]
spk_embed_dim
:
256
spk_embed_integration_type
:
add
# speaker embedding integration type
###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer
:
optim
:
adam
# optimizer type
learning_rate
:
0.002
# learning rate
max_grad_norm
:
1
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch
:
100
num_snapshots
:
5
###########################################################
# OTHER SETTING #
###########################################################
seed
:
10086
\ No newline at end of file
paddlespeech/t2s/datasets/am_batch_fn.py
浏览文件 @
11991b6d
...
@@ -17,7 +17,7 @@ import paddle
...
@@ -17,7 +17,7 @@ import paddle
from
paddlespeech.t2s.data.batch
import
batch_sequences
from
paddlespeech.t2s.data.batch
import
batch_sequences
def
speedyspeech_batch_fn
(
examples
):
def
speedyspeech_
single_spk_
batch_fn
(
examples
):
# fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
# fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
phones
=
[
np
.
array
(
item
[
"phones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
phones
=
[
np
.
array
(
item
[
"phones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
tones
=
[
np
.
array
(
item
[
"tones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
tones
=
[
np
.
array
(
item
[
"tones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
...
@@ -54,6 +54,46 @@ def speedyspeech_batch_fn(examples):
...
@@ -54,6 +54,46 @@ def speedyspeech_batch_fn(examples):
}
}
return
batch
return
batch
def
speedyspeech_multi_spk_batch_fn
(
examples
):
# fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
phones
=
[
np
.
array
(
item
[
"phones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
tones
=
[
np
.
array
(
item
[
"tones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
feats
=
[
np
.
array
(
item
[
"feats"
],
dtype
=
np
.
float32
)
for
item
in
examples
]
durations
=
[
np
.
array
(
item
[
"durations"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
num_phones
=
[
np
.
array
(
item
[
"num_phones"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
num_frames
=
[
np
.
array
(
item
[
"num_frames"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
phones
=
batch_sequences
(
phones
)
tones
=
batch_sequences
(
tones
)
feats
=
batch_sequences
(
feats
)
durations
=
batch_sequences
(
durations
)
# convert each batch to paddle.Tensor
phones
=
paddle
.
to_tensor
(
phones
)
tones
=
paddle
.
to_tensor
(
tones
)
feats
=
paddle
.
to_tensor
(
feats
)
durations
=
paddle
.
to_tensor
(
durations
)
num_phones
=
paddle
.
to_tensor
(
num_phones
)
num_frames
=
paddle
.
to_tensor
(
num_frames
)
batch
=
{
"phones"
:
phones
,
"tones"
:
tones
,
"num_phones"
:
num_phones
,
"num_frames"
:
num_frames
,
"feats"
:
feats
,
"durations"
:
durations
,
}
if
"spk_id"
in
examples
[
0
]:
spk_id
=
[
np
.
array
(
item
[
"spk_id"
],
dtype
=
np
.
int64
)
for
item
in
examples
]
spk_id
=
paddle
.
to_tensor
(
spk_id
)
batch
[
"spk_id"
]
=
spk_id
return
batch
def
fastspeech2_single_spk_batch_fn
(
examples
):
def
fastspeech2_single_spk_batch_fn
(
examples
):
# fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
# fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
...
...
paddlespeech/t2s/exps/speedyspeech/normalize.py
浏览文件 @
11991b6d
...
@@ -47,7 +47,8 @@ def main():
...
@@ -47,7 +47,8 @@ def main():
"--phones-dict"
,
type
=
str
,
default
=
None
,
help
=
"phone vocabulary file."
)
"--phones-dict"
,
type
=
str
,
default
=
None
,
help
=
"phone vocabulary file."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--tones-dict"
,
type
=
str
,
default
=
None
,
help
=
"tone vocabulary file."
)
"--tones-dict"
,
type
=
str
,
default
=
None
,
help
=
"tone vocabulary file."
)
parser
.
add_argument
(
"--speaker-dict"
,
type
=
str
,
default
=
None
,
help
=
"speaker id map file."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--verbose"
,
"--verbose"
,
type
=
int
,
type
=
int
,
...
@@ -121,6 +122,12 @@ def main():
...
@@ -121,6 +122,12 @@ def main():
for
tone
,
id
in
tone_id
:
for
tone
,
id
in
tone_id
:
vocab_tones
[
tone
]
=
int
(
id
)
vocab_tones
[
tone
]
=
int
(
id
)
vocab_speaker
=
{}
with
open
(
args
.
speaker_dict
,
'rt'
)
as
f
:
spk_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
for
spk
,
id
in
spk_id
:
vocab_speaker
[
spk
]
=
int
(
id
)
# process each file
# process each file
output_metadata
=
[]
output_metadata
=
[]
...
@@ -135,11 +142,13 @@ def main():
...
@@ -135,11 +142,13 @@ def main():
np
.
save
(
mel_path
,
mel
.
astype
(
np
.
float32
),
allow_pickle
=
False
)
np
.
save
(
mel_path
,
mel
.
astype
(
np
.
float32
),
allow_pickle
=
False
)
phone_ids
=
[
vocab_phones
[
p
]
for
p
in
item
[
'phones'
]]
phone_ids
=
[
vocab_phones
[
p
]
for
p
in
item
[
'phones'
]]
tone_ids
=
[
vocab_tones
[
p
]
for
p
in
item
[
'tones'
]]
tone_ids
=
[
vocab_tones
[
p
]
for
p
in
item
[
'tones'
]]
spk_id
=
vocab_speaker
[
item
[
"speaker"
]]
if
args
.
use_relative_path
:
if
args
.
use_relative_path
:
# convert absolute path to relative path:
# convert absolute path to relative path:
mel_path
=
mel_path
.
relative_to
(
dumpdir
)
mel_path
=
mel_path
.
relative_to
(
dumpdir
)
output_metadata
.
append
({
output_metadata
.
append
({
'utt_id'
:
utt_id
,
'utt_id'
:
utt_id
,
"spk_id"
:
spk_id
,
'phones'
:
phone_ids
,
'phones'
:
phone_ids
,
'tones'
:
tone_ids
,
'tones'
:
tone_ids
,
'num_phones'
:
item
[
'num_phones'
],
'num_phones'
:
item
[
'num_phones'
],
...
...
paddlespeech/t2s/exps/speedyspeech/preprocess.py
浏览文件 @
11991b6d
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
import
re
import
re
import
os
from
concurrent.futures
import
ThreadPoolExecutor
from
concurrent.futures
import
ThreadPoolExecutor
from
operator
import
itemgetter
from
operator
import
itemgetter
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -32,7 +33,7 @@ from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_
...
@@ -32,7 +33,7 @@ from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_
from
paddlespeech.t2s.datasets.preprocess_utils
import
get_phn_dur
from
paddlespeech.t2s.datasets.preprocess_utils
import
get_phn_dur
from
paddlespeech.t2s.datasets.preprocess_utils
import
get_phones_tones
from
paddlespeech.t2s.datasets.preprocess_utils
import
get_phones_tones
from
paddlespeech.t2s.datasets.preprocess_utils
import
merge_silence
from
paddlespeech.t2s.datasets.preprocess_utils
import
merge_silence
from
paddlespeech.t2s.datasets.preprocess_utils
import
get_spk_id_map
def
process_sentence
(
config
:
Dict
[
str
,
Any
],
def
process_sentence
(
config
:
Dict
[
str
,
Any
],
fp
:
Path
,
fp
:
Path
,
...
@@ -101,6 +102,7 @@ def process_sentence(config: Dict[str, Any],
...
@@ -101,6 +102,7 @@ def process_sentence(config: Dict[str, Any],
"utt_id"
:
utt_id
,
"utt_id"
:
utt_id
,
"phones"
:
phones
,
"phones"
:
phones
,
"tones"
:
tones
,
"tones"
:
tones
,
"speaker"
:
speaker
,
"num_phones"
:
len
(
phones
),
"num_phones"
:
len
(
phones
),
"num_frames"
:
num_frames
,
"num_frames"
:
num_frames
,
"durations"
:
durations
,
"durations"
:
durations
,
...
@@ -229,6 +231,8 @@ def main():
...
@@ -229,6 +231,8 @@ def main():
tone_id_map_path
=
dumpdir
/
"tone_id_map.txt"
tone_id_map_path
=
dumpdir
/
"tone_id_map.txt"
get_phones_tones
(
sentences
,
phone_id_map_path
,
tone_id_map_path
,
get_phones_tones
(
sentences
,
phone_id_map_path
,
tone_id_map_path
,
args
.
dataset
)
args
.
dataset
)
speaker_id_map_path
=
dumpdir
/
"speaker_id_map.txt"
get_spk_id_map
(
speaker_set
,
speaker_id_map_path
)
if
args
.
dataset
==
"baker"
:
if
args
.
dataset
==
"baker"
:
wav_files
=
sorted
(
list
((
rootdir
/
"Wave"
).
rglob
(
"*.wav"
)))
wav_files
=
sorted
(
list
((
rootdir
/
"Wave"
).
rglob
(
"*.wav"
)))
...
@@ -239,6 +243,28 @@ def main():
...
@@ -239,6 +243,28 @@ def main():
dev_wav_files
=
wav_files
[
num_train
:
num_train
+
num_dev
]
dev_wav_files
=
wav_files
[
num_train
:
num_train
+
num_dev
]
test_wav_files
=
wav_files
[
num_train
+
num_dev
:]
test_wav_files
=
wav_files
[
num_train
+
num_dev
:]
elif
args
.
dataset
==
"other"
:
sub_num_dev
=
100
wav_dir
=
rootdir
/
"wav"
train_wav_files
=
[]
dev_wav_files
=
[]
test_wav_files
=
[]
for
speaker
in
os
.
listdir
(
wav_dir
):
if
os
.
path
.
exists
(
os
.
path
.
join
(
wav_dir
,
speaker
,
"split"
)):
wav_files
=
sorted
(
list
((
wav_dir
/
speaker
/
"split"
).
rglob
(
"*.wav"
)))
else
:
wav_files
=
sorted
(
list
((
wav_dir
/
speaker
).
rglob
(
"*.wav"
)))
if
len
(
wav_files
)
>
100
:
train_wav_files
+=
wav_files
[:
-
sub_num_dev
*
2
]
dev_wav_files
+=
wav_files
[
-
sub_num_dev
*
2
:
-
sub_num_dev
]
test_wav_files
+=
wav_files
[
-
sub_num_dev
:]
else
:
train_wav_files
+=
wav_files
print
(
"len train_wav_files"
,
len
(
train_wav_files
))
print
(
"len dev_wav_files"
,
len
(
dev_wav_files
))
print
(
"len test_wav_files"
,
len
(
test_wav_files
))
train_dump_dir
=
dumpdir
/
"train"
/
"raw"
train_dump_dir
=
dumpdir
/
"train"
/
"raw"
train_dump_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
train_dump_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
dev_dump_dir
=
dumpdir
/
"dev"
/
"raw"
dev_dump_dir
=
dumpdir
/
"dev"
/
"raw"
...
...
paddlespeech/t2s/exps/speedyspeech/train.py
浏览文件 @
11991b6d
...
@@ -27,7 +27,8 @@ from paddle.io import DataLoader
...
@@ -27,7 +27,8 @@ from paddle.io import DataLoader
from
paddle.io
import
DistributedBatchSampler
from
paddle.io
import
DistributedBatchSampler
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
paddlespeech.t2s.datasets.am_batch_fn
import
speedyspeech_batch_fn
from
paddlespeech.t2s.datasets.am_batch_fn
import
speedyspeech_single_spk_batch_fn
from
paddlespeech.t2s.datasets.am_batch_fn
import
speedyspeech_multi_spk_batch_fn
from
paddlespeech.t2s.datasets.data_table
import
DataTable
from
paddlespeech.t2s.datasets.data_table
import
DataTable
from
paddlespeech.t2s.models.speedyspeech
import
SpeedySpeech
from
paddlespeech.t2s.models.speedyspeech
import
SpeedySpeech
from
paddlespeech.t2s.models.speedyspeech
import
SpeedySpeechEvaluator
from
paddlespeech.t2s.models.speedyspeech
import
SpeedySpeechEvaluator
...
@@ -57,6 +58,21 @@ def train_sp(args, config):
...
@@ -57,6 +58,21 @@ def train_sp(args, config):
f
"rank:
{
dist
.
get_rank
()
}
, pid:
{
os
.
getpid
()
}
, parent_pid:
{
os
.
getppid
()
}
"
,
f
"rank:
{
dist
.
get_rank
()
}
, pid:
{
os
.
getpid
()
}
, parent_pid:
{
os
.
getppid
()
}
"
,
)
)
fields
=
[
"phones"
,
"tones"
,
"num_phones"
,
"num_frames"
,
"feats"
,
"durations"
]
spk_num
=
None
if
args
.
speaker_dict
is
not
None
:
print
(
"multiple speaker speedyspeech!"
)
collate_fn
=
speedyspeech_multi_spk_batch_fn
with
open
(
args
.
speaker_dict
,
'rt'
)
as
f
:
spk_id
=
[
line
.
strip
().
split
()
for
line
in
f
.
readlines
()]
spk_num
=
len
(
spk_id
)
fields
+=
[
"spk_id"
]
else
:
print
(
"single speaker speedyspeech!"
)
collate_fn
=
speedyspeech_single_spk_batch_fn
print
(
"spk_num:"
,
spk_num
)
# dataloader has been too verbose
# dataloader has been too verbose
logging
.
getLogger
(
"DataLoader"
).
disabled
=
True
logging
.
getLogger
(
"DataLoader"
).
disabled
=
True
...
@@ -71,9 +87,7 @@ def train_sp(args, config):
...
@@ -71,9 +87,7 @@ def train_sp(args, config):
train_dataset
=
DataTable
(
train_dataset
=
DataTable
(
data
=
train_metadata
,
data
=
train_metadata
,
fields
=
[
fields
=
fields
,
"phones"
,
"tones"
,
"num_phones"
,
"num_frames"
,
"feats"
,
"durations"
],
converters
=
{
converters
=
{
"feats"
:
np
.
load
,
"feats"
:
np
.
load
,
},
)
},
)
...
@@ -87,9 +101,7 @@ def train_sp(args, config):
...
@@ -87,9 +101,7 @@ def train_sp(args, config):
dev_dataset
=
DataTable
(
dev_dataset
=
DataTable
(
data
=
dev_metadata
,
data
=
dev_metadata
,
fields
=
[
fields
=
fields
,
"phones"
,
"tones"
,
"num_phones"
,
"num_frames"
,
"feats"
,
"durations"
],
converters
=
{
converters
=
{
"feats"
:
np
.
load
,
"feats"
:
np
.
load
,
},
)
},
)
...
@@ -105,14 +117,14 @@ def train_sp(args, config):
...
@@ -105,14 +117,14 @@ def train_sp(args, config):
train_dataloader
=
DataLoader
(
train_dataloader
=
DataLoader
(
train_dataset
,
train_dataset
,
batch_sampler
=
train_sampler
,
batch_sampler
=
train_sampler
,
collate_fn
=
speedyspeech_batch
_fn
,
collate_fn
=
collate
_fn
,
num_workers
=
config
.
num_workers
)
num_workers
=
config
.
num_workers
)
dev_dataloader
=
DataLoader
(
dev_dataloader
=
DataLoader
(
dev_dataset
,
dev_dataset
,
shuffle
=
False
,
shuffle
=
False
,
drop_last
=
False
,
drop_last
=
False
,
batch_size
=
config
.
batch_size
,
batch_size
=
config
.
batch_size
,
collate_fn
=
speedyspeech_batch
_fn
,
collate_fn
=
collate
_fn
,
num_workers
=
config
.
num_workers
)
num_workers
=
config
.
num_workers
)
print
(
"dataloaders done!"
)
print
(
"dataloaders done!"
)
with
open
(
args
.
phones_dict
,
"r"
)
as
f
:
with
open
(
args
.
phones_dict
,
"r"
)
as
f
:
...
@@ -125,7 +137,7 @@ def train_sp(args, config):
...
@@ -125,7 +137,7 @@ def train_sp(args, config):
print
(
"tone_size:"
,
tone_size
)
print
(
"tone_size:"
,
tone_size
)
model
=
SpeedySpeech
(
model
=
SpeedySpeech
(
vocab_size
=
vocab_size
,
tone_size
=
tone_size
,
**
config
[
"model"
])
vocab_size
=
vocab_size
,
tone_size
=
tone_size
,
spk_num
=
spk_num
,
**
config
[
"model"
])
if
world_size
>
1
:
if
world_size
>
1
:
model
=
DataParallel
(
model
)
model
=
DataParallel
(
model
)
print
(
"model done!"
)
print
(
"model done!"
)
...
@@ -184,6 +196,12 @@ def main():
...
@@ -184,6 +196,12 @@ def main():
parser
.
add_argument
(
parser
.
add_argument
(
"--tones-dict"
,
type
=
str
,
default
=
None
,
help
=
"tone vocabulary file."
)
"--tones-dict"
,
type
=
str
,
default
=
None
,
help
=
"tone vocabulary file."
)
parser
.
add_argument
(
"--speaker-dict"
,
type
=
str
,
default
=
None
,
help
=
"speaker id map file for multiple speaker model."
)
# 这里可以多传入 max_epoch 等
# 这里可以多传入 max_epoch 等
args
,
rest
=
parser
.
parse_known_args
()
args
,
rest
=
parser
.
parse_known_args
()
...
...
paddlespeech/t2s/models/speedyspeech/speedyspeech.py
浏览文件 @
11991b6d
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
import
paddle.nn.functional
as
F
from
paddlespeech.t2s.modules.positional_encoding
import
sinusoid_position_encoding
from
paddlespeech.t2s.modules.positional_encoding
import
sinusoid_position_encoding
...
@@ -171,7 +171,11 @@ class SpeedySpeech(nn.Layer):
...
@@ -171,7 +171,11 @@ class SpeedySpeech(nn.Layer):
decoder_output_size
,
decoder_output_size
,
decoder_kernel_size
,
decoder_kernel_size
,
decoder_dilations
,
decoder_dilations
,
tone_size
=
None
,
):
tone_size
=
None
,
spk_num
:
int
=
None
,
spk_embed_dim
:
int
=
None
,
spk_embed_integration_type
:
str
=
"add"
,
):
super
().
__init__
()
super
().
__init__
()
encoder
=
SpeedySpeechEncoder
(
vocab_size
,
tone_size
,
encoder
=
SpeedySpeechEncoder
(
vocab_size
,
tone_size
,
encoder_hidden_size
,
encoder_kernel_size
,
encoder_hidden_size
,
encoder_kernel_size
,
...
@@ -183,14 +187,43 @@ class SpeedySpeech(nn.Layer):
...
@@ -183,14 +187,43 @@ class SpeedySpeech(nn.Layer):
self
.
encoder
=
encoder
self
.
encoder
=
encoder
self
.
duration_predictor
=
duration_predictor
self
.
duration_predictor
=
duration_predictor
self
.
decoder
=
decoder
self
.
decoder
=
decoder
self
.
spk_embed_dim
=
spk_embed_dim
def
forward
(
self
,
text
,
tones
,
durations
):
# use idx 0 as padding idx
self
.
padding_idx
=
0
if
self
.
spk_embed_dim
is
not
None
:
self
.
spk_embed_integration_type
=
spk_embed_integration_type
if
spk_num
and
self
.
spk_embed_dim
:
self
.
spk_embedding_table
=
nn
.
Embedding
(
num_embeddings
=
spk_num
,
embedding_dim
=
self
.
spk_embed_dim
,
padding_idx
=
self
.
padding_idx
)
self
.
encoder_hidden_size
=
encoder_hidden_size
# define additional projection for speaker embedding
if
self
.
spk_embed_dim
is
not
None
:
print
(
"spk_embed_integration_type------------"
,
spk_embed_integration_type
)
if
self
.
spk_embed_integration_type
==
"add"
:
self
.
spk_projection
=
nn
.
Linear
(
self
.
spk_embed_dim
,
self
.
encoder_hidden_size
)
else
:
self
.
spk_projection
=
nn
.
Linear
(
self
.
encoder_hidden_size
+
self
.
spk_embed_dim
,
self
.
encoder_hidden_size
)
def
forward
(
self
,
text
,
tones
,
durations
,
spk_id
:
paddle
.
Tensor
=
None
):
# input of embedding must be int64
# input of embedding must be int64
text
=
paddle
.
cast
(
text
,
'int64'
)
text
=
paddle
.
cast
(
text
,
'int64'
)
tones
=
paddle
.
cast
(
tones
,
'int64'
)
tones
=
paddle
.
cast
(
tones
,
'int64'
)
if
spk_id
is
not
None
:
spk_id
=
paddle
.
cast
(
spk_id
,
'int64'
)
durations
=
paddle
.
cast
(
durations
,
'int64'
)
durations
=
paddle
.
cast
(
durations
,
'int64'
)
encodings
=
self
.
encoder
(
text
,
tones
)
encodings
=
self
.
encoder
(
text
,
tones
)
# (B, T)
# (B, T)
if
self
.
spk_embed_dim
is
not
None
:
if
spk_id
is
not
None
:
spk_emb
=
self
.
spk_embedding_table
(
spk_id
)
encodings
=
self
.
_integrate_with_spk_embed
(
encodings
,
spk_emb
)
pred_durations
=
self
.
duration_predictor
(
encodings
.
detach
())
pred_durations
=
self
.
duration_predictor
(
encodings
.
detach
())
# expand encodings
# expand encodings
...
@@ -204,7 +237,7 @@ class SpeedySpeech(nn.Layer):
...
@@ -204,7 +237,7 @@ class SpeedySpeech(nn.Layer):
decoded
=
self
.
decoder
(
encodings
)
decoded
=
self
.
decoder
(
encodings
)
return
decoded
,
pred_durations
return
decoded
,
pred_durations
def
inference
(
self
,
text
,
tones
=
None
):
def
inference
(
self
,
text
,
tones
=
None
,
spk_id
=
None
,
):
# text: [T]
# text: [T]
# tones: [T]
# tones: [T]
# input of embedding must be int64
# input of embedding must be int64
...
@@ -215,6 +248,11 @@ class SpeedySpeech(nn.Layer):
...
@@ -215,6 +248,11 @@ class SpeedySpeech(nn.Layer):
tones
=
tones
.
unsqueeze
(
0
)
tones
=
tones
.
unsqueeze
(
0
)
encodings
=
self
.
encoder
(
text
,
tones
)
encodings
=
self
.
encoder
(
text
,
tones
)
if
self
.
spk_embed_dim
is
not
None
:
if
spk_id
is
not
None
:
spk_emb
=
self
.
spk_embedding_table
(
spk_id
)
encodings
=
self
.
_integrate_with_spk_embed
(
encodings
,
spk_emb
)
pred_durations
=
self
.
duration_predictor
(
encodings
)
# (1, T)
pred_durations
=
self
.
duration_predictor
(
encodings
)
# (1, T)
durations_to_expand
=
paddle
.
round
(
pred_durations
.
exp
())
durations_to_expand
=
paddle
.
round
(
pred_durations
.
exp
())
durations_to_expand
=
(
durations_to_expand
).
astype
(
paddle
.
int64
)
durations_to_expand
=
(
durations_to_expand
).
astype
(
paddle
.
int64
)
...
@@ -240,6 +278,34 @@ class SpeedySpeech(nn.Layer):
...
@@ -240,6 +278,34 @@ class SpeedySpeech(nn.Layer):
decoded
=
self
.
decoder
(
encodings
)
decoded
=
self
.
decoder
(
encodings
)
return
decoded
[
0
]
return
decoded
[
0
]
def
_integrate_with_spk_embed
(
self
,
hs
,
spk_emb
):
"""Integrate speaker embedding with hidden states.
Parameters
----------
hs : Tensor
Batch of hidden state sequences (B, Tmax, adim).
spk_emb : Tensor
Batch of speaker embeddings (B, spk_embed_dim).
Returns
----------
Tensor
Batch of integrated hidden state sequences (B, Tmax, adim)
"""
if
self
.
spk_embed_integration_type
==
"add"
:
# apply projection and then add to hidden states
spk_emb
=
self
.
spk_projection
(
F
.
normalize
(
spk_emb
))
hs
=
hs
+
spk_emb
.
unsqueeze
(
1
)
elif
self
.
spk_embed_integration_type
==
"concat"
:
# concat hidden states with spk embeds and then apply projection
spk_emb
=
F
.
normalize
(
spk_emb
).
unsqueeze
(
1
).
expand
(
shape
=
[
-
1
,
hs
.
shape
[
1
],
-
1
])
hs
=
self
.
spk_projection
(
paddle
.
concat
([
hs
,
spk_emb
],
axis
=-
1
))
else
:
raise
NotImplementedError
(
"support only add or concat."
)
return
hs
class
SpeedySpeechInference
(
nn
.
Layer
):
class
SpeedySpeechInference
(
nn
.
Layer
):
def
__init__
(
self
,
normalizer
,
speedyspeech_model
):
def
__init__
(
self
,
normalizer
,
speedyspeech_model
):
...
@@ -247,7 +313,7 @@ class SpeedySpeechInference(nn.Layer):
...
@@ -247,7 +313,7 @@ class SpeedySpeechInference(nn.Layer):
self
.
normalizer
=
normalizer
self
.
normalizer
=
normalizer
self
.
acoustic_model
=
speedyspeech_model
self
.
acoustic_model
=
speedyspeech_model
def
forward
(
self
,
phones
,
tones
):
def
forward
(
self
,
phones
,
tones
,
spk_id
=
None
):
normalized_mel
=
self
.
acoustic_model
.
inference
(
phones
,
tones
)
normalized_mel
=
self
.
acoustic_model
.
inference
(
phones
,
tones
,
spk_id
)
logmel
=
self
.
normalizer
.
inverse
(
normalized_mel
)
logmel
=
self
.
normalizer
.
inverse
(
normalized_mel
)
return
logmel
return
logmel
paddlespeech/t2s/models/speedyspeech/speedyspeech_updater.py
浏览文件 @
11991b6d
...
@@ -50,10 +50,15 @@ class SpeedySpeechUpdater(StandardUpdater):
...
@@ -50,10 +50,15 @@ class SpeedySpeechUpdater(StandardUpdater):
self
.
msg
=
"Rank: {}, "
.
format
(
dist
.
get_rank
())
self
.
msg
=
"Rank: {}, "
.
format
(
dist
.
get_rank
())
losses_dict
=
{}
losses_dict
=
{}
# spk_id!=None in multiple spk speedyspeech
spk_id
=
batch
[
"spk_id"
]
if
"spk_id"
in
batch
else
None
decoded
,
predicted_durations
=
self
.
model
(
decoded
,
predicted_durations
=
self
.
model
(
text
=
batch
[
"phones"
],
text
=
batch
[
"phones"
],
tones
=
batch
[
"tones"
],
tones
=
batch
[
"tones"
],
durations
=
batch
[
"durations"
])
durations
=
batch
[
"durations"
],
spk_id
=
spk_id
)
target_mel
=
batch
[
"feats"
]
target_mel
=
batch
[
"feats"
]
spec_mask
=
F
.
sequence_mask
(
spec_mask
=
F
.
sequence_mask
(
...
@@ -112,10 +117,14 @@ class SpeedySpeechEvaluator(StandardEvaluator):
...
@@ -112,10 +117,14 @@ class SpeedySpeechEvaluator(StandardEvaluator):
self
.
msg
=
"Evaluate: "
self
.
msg
=
"Evaluate: "
losses_dict
=
{}
losses_dict
=
{}
spk_id
=
batch
[
"spk_id"
]
if
"spk_id"
in
batch
else
None
decoded
,
predicted_durations
=
self
.
model
(
decoded
,
predicted_durations
=
self
.
model
(
text
=
batch
[
"phones"
],
text
=
batch
[
"phones"
],
tones
=
batch
[
"tones"
],
tones
=
batch
[
"tones"
],
durations
=
batch
[
"durations"
])
durations
=
batch
[
"durations"
],
spk_id
=
spk_id
)
target_mel
=
batch
[
"feats"
]
target_mel
=
batch
[
"feats"
]
spec_mask
=
F
.
sequence_mask
(
spec_mask
=
F
.
sequence_mask
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录