Unverified commit 98ce69d0, authored by 小湉湉, committed by GitHub

Merge pull request #1259 from jerryuhoo/develop

[TTS]Add multi-speaker support for the SpeedySpeech model
@@ -45,6 +45,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
         --stats=dump/train/feats_stats.npy \
         --phones-dict=dump/phone_id_map.txt \
         --tones-dict=dump/tone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt \
         --use-relative-path=True
     python3 ${BIN_DIR}/normalize.py \
@@ -53,6 +54,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
         --stats=dump/train/feats_stats.npy \
         --phones-dict=dump/phone_id_map.txt \
         --tones-dict=dump/tone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt \
         --use-relative-path=True
     python3 ${BIN_DIR}/normalize.py \
@@ -61,6 +63,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
         --stats=dump/train/feats_stats.npy \
         --phones-dict=dump/phone_id_map.txt \
         --tones-dict=dump/tone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt \
         --use-relative-path=True
 fi
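All three `normalize.py` invocations (train/dev/test) now receive the same `--speaker-dict`. Judging from how `normalize.py` parses it in a later hunk (`line.strip().split()` per line), `dump/speaker_id_map.txt` is expected to hold one `<speaker> <id>` pair per line; the sample content below is only an assumption for a single-speaker corpus:

```python
# Hypothetical contents of dump/speaker_id_map.txt for a single-speaker
# corpus (one "<speaker> <id>" pair per line):
#
#   baker 0
#
# Reading it back mirrors the parsing normalize.py does below.
with open("dump/speaker_id_map.txt", "rt") as f:
    vocab_speaker = {spk: int(spk_id)
                     for spk, spk_id in (line.strip().split() for line in f)}
print(vocab_speaker)  # e.g. {'baker': 0}
```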
@@ -17,7 +17,7 @@ import paddle
 from paddlespeech.t2s.data.batch import batch_sequences
-def speedyspeech_batch_fn(examples):
+def speedyspeech_single_spk_batch_fn(examples):
     # fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
     phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
     tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
@@ -54,6 +54,46 @@ def speedyspeech_batch_fn(examples):
     }
     return batch
+def speedyspeech_multi_spk_batch_fn(examples):
+    # fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
+    phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
+    tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
+    feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
+    durations = [
+        np.array(item["durations"], dtype=np.int64) for item in examples
+    ]
+    num_phones = [
+        np.array(item["num_phones"], dtype=np.int64) for item in examples
+    ]
+    num_frames = [
+        np.array(item["num_frames"], dtype=np.int64) for item in examples
+    ]
+    phones = batch_sequences(phones)
+    tones = batch_sequences(tones)
+    feats = batch_sequences(feats)
+    durations = batch_sequences(durations)
+    # convert each batch to paddle.Tensor
+    phones = paddle.to_tensor(phones)
+    tones = paddle.to_tensor(tones)
+    feats = paddle.to_tensor(feats)
+    durations = paddle.to_tensor(durations)
+    num_phones = paddle.to_tensor(num_phones)
+    num_frames = paddle.to_tensor(num_frames)
+    batch = {
+        "phones": phones,
+        "tones": tones,
+        "num_phones": num_phones,
+        "num_frames": num_frames,
+        "feats": feats,
+        "durations": durations,
+    }
+    if "spk_id" in examples[0]:
+        spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
+        spk_id = paddle.to_tensor(spk_id)
+        batch["spk_id"] = spk_id
+    return batch
 def fastspeech2_single_spk_batch_fn(examples):
     # fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
...
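A quick way to see what the new collate function adds over the single-speaker one is to feed it two fake utterances. This is only a toy check; it assumes paddlespeech and paddle are installed, and the 80 mel bins are arbitrary:

```python
import numpy as np
from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_multi_spk_batch_fn

# Two fake utterances; shapes are arbitrary, 80 mel bins assumed.
examples = [
    {"phones": [1, 2, 3], "tones": [0, 1, 0], "num_phones": 3, "num_frames": 5,
     "feats": np.zeros((5, 80), dtype=np.float32), "durations": [1, 2, 2],
     "spk_id": 0},
    {"phones": [4, 5], "tones": [1, 1], "num_phones": 2, "num_frames": 4,
     "feats": np.zeros((4, 80), dtype=np.float32), "durations": [2, 2],
     "spk_id": 1},
]
batch = speedyspeech_multi_spk_batch_fn(examples)
# Identical to the single-speaker batch except for the extra key:
print(batch["spk_id"])  # Tensor([0, 1]); omitted when examples lack "spk_id"
```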
@@ -47,7 +47,8 @@ def main():
         "--phones-dict", type=str, default=None, help="phone vocabulary file.")
     parser.add_argument(
         "--tones-dict", type=str, default=None, help="tone vocabulary file.")
+    parser.add_argument(
+        "--speaker-dict", type=str, default=None, help="speaker id map file.")
     parser.add_argument(
         "--verbose",
         type=int,
@@ -121,6 +122,12 @@ def main():
     for tone, id in tone_id:
         vocab_tones[tone] = int(id)
+    vocab_speaker = {}
+    with open(args.speaker_dict, 'rt') as f:
+        spk_id = [line.strip().split() for line in f.readlines()]
+    for spk, id in spk_id:
+        vocab_speaker[spk] = int(id)
     # process each file
     output_metadata = []
@@ -135,11 +142,13 @@ def main():
         np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
         phone_ids = [vocab_phones[p] for p in item['phones']]
         tone_ids = [vocab_tones[p] for p in item['tones']]
+        spk_id = vocab_speaker[item["speaker"]]
         if args.use_relative_path:
             # convert absolute path to relative path:
             mel_path = mel_path.relative_to(dumpdir)
         output_metadata.append({
             'utt_id': utt_id,
+            "spk_id": spk_id,
             'phones': phone_ids,
             'tones': tone_ids,
             'num_phones': item['num_phones'],
...
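With this change every row of the normalized metadata carries the utterance's speaker id, resolved through the map. A representative entry (field values invented purely for illustration) would look like:

```python
# Illustrative metadata entry after normalization (values made up);
# real entries also carry the feats path and durations appended
# further down in main().
entry = {
    "utt_id": "009901",     # utterance name
    "spk_id": 0,            # looked up via vocab_speaker[item["speaker"]]
    "phones": [23, 51, 7],  # ids from phone_id_map.txt
    "tones": [2, 4, 1],     # ids from tone_id_map.txt
    "num_phones": 3,
}
```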
@@ -13,6 +13,7 @@
 # limitations under the License.
 import argparse
 import re
+import os
 from concurrent.futures import ThreadPoolExecutor
 from operator import itemgetter
 from pathlib import Path
@@ -32,7 +33,7 @@ from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_
 from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
 from paddlespeech.t2s.datasets.preprocess_utils import get_phones_tones
 from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
+from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
 def process_sentence(config: Dict[str, Any],
                      fp: Path,
@@ -101,6 +102,7 @@ def process_sentence(config: Dict[str, Any],
             "utt_id": utt_id,
             "phones": phones,
             "tones": tones,
+            "speaker": speaker,
             "num_phones": len(phones),
             "num_frames": num_frames,
             "durations": durations,
@@ -229,6 +231,8 @@ def main():
     tone_id_map_path = dumpdir / "tone_id_map.txt"
     get_phones_tones(sentences, phone_id_map_path, tone_id_map_path,
                      args.dataset)
+    speaker_id_map_path = dumpdir / "speaker_id_map.txt"
+    get_spk_id_map(speaker_set, speaker_id_map_path)
     if args.dataset == "baker":
         wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
...
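`get_spk_id_map` is imported from `preprocess_utils` rather than defined in this diff. A plausible sketch of its behavior, consistent with the `<speaker> <id>` format consumed by `normalize.py` above (the sorting choice is an assumption, not confirmed by this diff):

```python
# Plausible sketch of get_spk_id_map: give each speaker a stable integer
# id and write one "<speaker> <id>" line per speaker. Sorting for
# determinism is an assumption.
def get_spk_id_map_sketch(speaker_set, output_path):
    with open(output_path, "wt") as f:
        for spk_id, spk in enumerate(sorted(speaker_set)):
            f.write(f"{spk} {spk_id}\n")
```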
@@ -27,7 +27,8 @@ from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
 from yacs.config import CfgNode
-from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_batch_fn
+from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_single_spk_batch_fn
+from paddlespeech.t2s.datasets.am_batch_fn import speedyspeech_multi_spk_batch_fn
 from paddlespeech.t2s.datasets.data_table import DataTable
 from paddlespeech.t2s.models.speedyspeech import SpeedySpeech
 from paddlespeech.t2s.models.speedyspeech import SpeedySpeechEvaluator
@@ -57,6 +58,21 @@ def train_sp(args, config):
         f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
     )
+    fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
+    spk_num = None
+    if args.speaker_dict is not None:
+        print("multiple speaker speedyspeech!")
+        collate_fn = speedyspeech_multi_spk_batch_fn
+        with open(args.speaker_dict, 'rt') as f:
+            spk_id = [line.strip().split() for line in f.readlines()]
+        spk_num = len(spk_id)
+        fields += ["spk_id"]
+    else:
+        print("single speaker speedyspeech!")
+        collate_fn = speedyspeech_single_spk_batch_fn
+    print("spk_num:", spk_num)
     # dataloader has been too verbose
     logging.getLogger("DataLoader").disabled = True
@@ -71,9 +87,7 @@ def train_sp(args, config):
     train_dataset = DataTable(
         data=train_metadata,
-        fields=[
-            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
-        ],
+        fields=fields,
         converters={
             "feats": np.load,
         }, )
@@ -87,9 +101,7 @@ def train_sp(args, config):
     dev_dataset = DataTable(
         data=dev_metadata,
-        fields=[
-            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
-        ],
+        fields=fields,
         converters={
             "feats": np.load,
         }, )
@@ -105,14 +117,14 @@ def train_sp(args, config):
     train_dataloader = DataLoader(
         train_dataset,
         batch_sampler=train_sampler,
-        collate_fn=speedyspeech_batch_fn,
+        collate_fn=collate_fn,
         num_workers=config.num_workers)
     dev_dataloader = DataLoader(
         dev_dataset,
         shuffle=False,
         drop_last=False,
         batch_size=config.batch_size,
-        collate_fn=speedyspeech_batch_fn,
+        collate_fn=collate_fn,
         num_workers=config.num_workers)
     print("dataloaders done!")
     with open(args.phones_dict, "r") as f:
@@ -125,7 +137,7 @@ def train_sp(args, config):
     print("tone_size:", tone_size)
     model = SpeedySpeech(
-        vocab_size=vocab_size, tone_size=tone_size, **config["model"])
+        vocab_size=vocab_size, tone_size=tone_size, spk_num=spk_num, **config["model"])
     if world_size > 1:
         model = DataParallel(model)
     print("model done!")
@@ -184,6 +196,12 @@ def main():
     parser.add_argument(
         "--tones-dict", type=str, default=None, help="tone vocabulary file.")
+    parser.add_argument(
+        "--speaker-dict",
+        type=str,
+        default=None,
+        help="speaker id map file for multiple speaker model.")
     # more arguments, such as max_epoch, can be passed in here
     args, rest = parser.parse_known_args()
...
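Pulling the training-side changes together: the speaker count is simply the number of lines in the map file, and passing `spk_num=None` keeps the model single-speaker. A condensed sketch of that path, with variable names taken from the diff and the file name assumed:

```python
# Condensed sketch of the branch train_sp takes for multi-speaker data.
speaker_dict = "dump/speaker_id_map.txt"  # assumed path; None => single speaker
spk_num = None
if speaker_dict is not None:
    with open(speaker_dict, "rt") as f:
        spk_num = len([line for line in f if line.strip()])
# spk_num then reaches the model as
#   SpeedySpeech(vocab_size=..., tone_size=..., spk_num=spk_num, **config["model"])
```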
@@ -14,7 +14,7 @@
 import numpy as np
 import paddle
 from paddle import nn
-import paddle.nn.functional as F
 from paddlespeech.t2s.modules.positional_encoding import sinusoid_position_encoding
@@ -96,7 +96,7 @@ class TextEmbedding(nn.Layer):
 class SpeedySpeechEncoder(nn.Layer):
     def __init__(self, vocab_size, tone_size, hidden_size, kernel_size,
-                 dilations):
+                 dilations, spk_num=None):
         super().__init__()
         self.embedding = TextEmbedding(
             vocab_size,
@@ -104,6 +104,15 @@ class SpeedySpeechEncoder(nn.Layer):
             tone_size,
             padding_idx=0,
             tone_padding_idx=0)
+        if spk_num:
+            self.spk_emb = nn.Embedding(
+                num_embeddings=spk_num,
+                embedding_dim=hidden_size,
+                padding_idx=0)
+        else:
+            self.spk_emb = None
         self.prenet = nn.Sequential(
             nn.Linear(hidden_size, hidden_size),
             nn.ReLU(), )
@@ -118,8 +127,10 @@ class SpeedySpeechEncoder(nn.Layer):
             nn.BatchNorm1D(hidden_size, data_format="NLC"),
             nn.Linear(hidden_size, hidden_size), )
-    def forward(self, text, tones):
+    def forward(self, text, tones, spk_id=None):
         embedding = self.embedding(text, tones)
+        if self.spk_emb:
+            embedding += self.spk_emb(spk_id).unsqueeze(1)
         embedding = self.prenet(embedding)
         x = self.res_blocks(embedding)
         x = embedding + self.postnet1(x)
@@ -171,11 +182,12 @@ class SpeedySpeech(nn.Layer):
                  decoder_output_size,
                  decoder_kernel_size,
                  decoder_dilations,
-                 tone_size=None, ):
+                 tone_size=None,
+                 spk_num=None):
         super().__init__()
         encoder = SpeedySpeechEncoder(vocab_size, tone_size,
                                       encoder_hidden_size, encoder_kernel_size,
-                                      encoder_dilations)
+                                      encoder_dilations, spk_num)
         duration_predictor = DurationPredictor(duration_predictor_hidden_size)
         decoder = SpeedySpeechDecoder(decoder_hidden_size, decoder_output_size,
                                       decoder_kernel_size, decoder_dilations)
@@ -184,13 +196,15 @@ class SpeedySpeech(nn.Layer):
         self.duration_predictor = duration_predictor
         self.decoder = decoder
-    def forward(self, text, tones, durations):
+    def forward(self, text, tones, durations, spk_id: paddle.Tensor=None):
         # input of embedding must be int64
         text = paddle.cast(text, 'int64')
         tones = paddle.cast(tones, 'int64')
+        if spk_id is not None:
+            spk_id = paddle.cast(spk_id, 'int64')
         durations = paddle.cast(durations, 'int64')
-        encodings = self.encoder(text, tones)
+        encodings = self.encoder(text, tones, spk_id)
+        # (B, T)
         pred_durations = self.duration_predictor(encodings.detach())
         # expand encodings
@@ -204,7 +218,7 @@ class SpeedySpeech(nn.Layer):
         decoded = self.decoder(encodings)
         return decoded, pred_durations
-    def inference(self, text, tones=None):
+    def inference(self, text, tones=None, spk_id=None):
         # text: [T]
         # tones: [T]
         # input of embedding must be int64
@@ -214,7 +228,8 @@ class SpeedySpeech(nn.Layer):
             tones = paddle.cast(tones, 'int64')
             tones = tones.unsqueeze(0)
-        encodings = self.encoder(text, tones)
+        encodings = self.encoder(text, tones, spk_id)
         pred_durations = self.duration_predictor(encodings)  # (1, T)
         durations_to_expand = paddle.round(pred_durations.exp())
         durations_to_expand = (durations_to_expand).astype(paddle.int64)
@@ -240,14 +255,13 @@ class SpeedySpeech(nn.Layer):
         decoded = self.decoder(encodings)
         return decoded[0]
 class SpeedySpeechInference(nn.Layer):
     def __init__(self, normalizer, speedyspeech_model):
         super().__init__()
         self.normalizer = normalizer
         self.acoustic_model = speedyspeech_model
-    def forward(self, phones, tones):
-        normalized_mel = self.acoustic_model.inference(phones, tones)
+    def forward(self, phones, tones, spk_id=None):
+        normalized_mel = self.acoustic_model.inference(phones, tones, spk_id)
         logmel = self.normalizer.inverse(normalized_mel)
         return logmel
...
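The only arithmetic the encoder adds is a broadcast: `spk_emb(spk_id)` has shape (B, hidden_size), and `unsqueeze(1)` turns it into (B, 1, hidden_size) so it adds across every time step of the (B, T, hidden_size) text/tone embedding. A standalone shape check (sizes arbitrary):

```python
import paddle

# Shape check for the speaker-embedding broadcast in SpeedySpeechEncoder.
B, T, hidden_size, spk_num = 2, 7, 128, 10
embedding = paddle.zeros([B, T, hidden_size])   # stand-in for embedding(text, tones)
spk_emb = paddle.nn.Embedding(spk_num, hidden_size, padding_idx=0)
spk_id = paddle.to_tensor([1, 3], dtype='int64')
out = embedding + spk_emb(spk_id).unsqueeze(1)  # (B, 1, H) broadcasts over T
print(out.shape)  # [2, 7, 128]
```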
@@ -50,10 +50,15 @@ class SpeedySpeechUpdater(StandardUpdater):
         self.msg = "Rank: {}, ".format(dist.get_rank())
         losses_dict = {}
+        # spk_id!=None in multiple spk speedyspeech
+        spk_id = batch["spk_id"] if "spk_id" in batch else None
         decoded, predicted_durations = self.model(
             text=batch["phones"],
             tones=batch["tones"],
-            durations=batch["durations"])
+            durations=batch["durations"],
+            spk_id=spk_id
+        )
         target_mel = batch["feats"]
         spec_mask = F.sequence_mask(
@@ -112,10 +117,14 @@ class SpeedySpeechEvaluator(StandardEvaluator):
         self.msg = "Evaluate: "
         losses_dict = {}
+        spk_id = batch["spk_id"] if "spk_id" in batch else None
         decoded, predicted_durations = self.model(
             text=batch["phones"],
             tones=batch["tones"],
-            durations=batch["durations"])
+            durations=batch["durations"],
+            spk_id=spk_id
+        )
         target_mel = batch["feats"]
         spec_mask = F.sequence_mask(
...