未验证 提交 e1817022 编写于 作者: L liangym 提交者: GitHub

[tts] add adversarial loss (#2588)

上级 9aab706c
......@@ -116,6 +116,8 @@ optional arguments:
5. `--phones-dict` is the path of the phone vocabulary file.
6. `--speaker-dict` is the path of the speaker id map file when training a multi-speaker FastSpeech2.
We have **added module speaker classifier** with reference to [Learning to Speak Fluently in a Foreign Language: Multilingual Speech Synthesis and Cross-Language Voice Cloning](https://arxiv.org/pdf/1907.04448.pdf). The main parameter configuration: `config["model"]["enable_speaker_classifier"]`, `config["model"]["hidden_sc_dim"]` and `config["updater"]["spk_loss_scale"]` in `conf/default.yaml`. The current experimental results show that this module can decouple text information and speaker information, and more experiments are still being sorted out. This module is currently not enabled by default, if you are interested, you can try it yourself.
### Synthesizing
We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1) as the default neural vocoder.
......
......@@ -74,6 +74,9 @@ model:
stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder
spk_embed_dim: 256 # speaker embedding dimension
spk_embed_integration_type: concat # speaker embedding integration type
enable_speaker_classifier: False # Whether to use speaker classifier module
hidden_sc_dim: 256 # The hidden layer dim of speaker classifier
......@@ -82,6 +85,7 @@ model:
###########################################################
updater:
use_masking: True # whether to apply masking for padded part in loss calculation
spk_loss_scale: 0.02 # The scales of speaker classifier loss
###########################################################
......
......@@ -145,17 +145,27 @@ def train_sp(args, config):
# copy conf to output_dir
shutil.copyfile(args.config, output_dir / config_name)
if "enable_speaker_classifier" in config.model:
enable_spk_cls = config.model.enable_speaker_classifier
else:
enable_spk_cls = False
updater = FastSpeech2Updater(
model=model,
optimizer=optimizer,
dataloader=train_dataloader,
output_dir=output_dir,
**config["updater"])
enable_spk_cls=enable_spk_cls,
**config["updater"], )
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = FastSpeech2Evaluator(
model, dev_dataloader, output_dir=output_dir, **config["updater"])
model,
dev_dataloader,
output_dir=output_dir,
enable_spk_cls=enable_spk_cls,
**config["updater"], )
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
......
......@@ -25,6 +25,8 @@ import paddle.nn.functional as F
from paddle import nn
from typeguard import check_argument_types
from paddlespeech.t2s.modules.adversarial_loss.gradient_reversal import GradientReversalLayer
from paddlespeech.t2s.modules.adversarial_loss.speaker_classifier import SpeakerClassifier
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
......@@ -138,7 +140,10 @@ class FastSpeech2(nn.Layer):
# training related
init_type: str="xavier_uniform",
init_enc_alpha: float=1.0,
init_dec_alpha: float=1.0, ):
init_dec_alpha: float=1.0,
# speaker classifier
enable_speaker_classifier: bool=False,
hidden_sc_dim: int=256, ):
"""Initialize FastSpeech2 module.
Args:
idim (int):
......@@ -268,6 +273,10 @@ class FastSpeech2(nn.Layer):
Initial value of alpha in scaled pos encoding of the encoder.
init_dec_alpha (float):
Initial value of alpha in scaled pos encoding of the decoder.
enable_speaker_classifier (bool):
Whether to use speaker classifier module
hidden_sc_dim (int):
The hidden layer dim of speaker classifier
"""
assert check_argument_types()
......@@ -281,6 +290,9 @@ class FastSpeech2(nn.Layer):
self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
self.use_scaled_pos_enc = use_scaled_pos_enc
self.hidden_sc_dim = hidden_sc_dim
self.spk_num = spk_num
self.enable_speaker_classifier = enable_speaker_classifier
self.spk_embed_dim = spk_embed_dim
if self.spk_embed_dim is not None:
......@@ -373,6 +385,12 @@ class FastSpeech2(nn.Layer):
self.tone_projection = nn.Linear(adim + self.tone_embed_dim,
adim)
if self.spk_num and self.enable_speaker_classifier:
# set lambda = 1
self.grad_reverse = GradientReversalLayer(1)
self.speaker_classifier = SpeakerClassifier(
idim=adim, hidden_sc_dim=self.hidden_sc_dim, spk_num=spk_num)
# define duration predictor
self.duration_predictor = DurationPredictor(
idim=adim,
......@@ -547,7 +565,7 @@ class FastSpeech2(nn.Layer):
if tone_id is not None:
tone_id = paddle.cast(tone_id, 'int64')
# forward propagation
before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(
before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward(
xs,
ilens,
olens,
......@@ -564,7 +582,7 @@ class FastSpeech2(nn.Layer):
max_olen = max(olens)
ys = ys[:, :max_olen]
return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens
return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits
def _forward(self,
xs: paddle.Tensor,
......@@ -584,6 +602,12 @@ class FastSpeech2(nn.Layer):
# (B, Tmax, adim)
hs, _ = self.encoder(xs, x_masks)
if self.spk_num and self.enable_speaker_classifier and not is_inference:
hs_for_spk_cls = self.grad_reverse(hs)
spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens)
else:
spk_logits = None
# integrate speaker embedding
if self.spk_embed_dim is not None:
# spk_emb has a higher priority than spk_id
......@@ -676,7 +700,7 @@ class FastSpeech2(nn.Layer):
after_outs = before_outs + self.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
return before_outs, after_outs, d_outs, p_outs, e_outs
return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits
def encoder_infer(
self,
......@@ -771,7 +795,7 @@ class FastSpeech2(nn.Layer):
es = e.unsqueeze(0) if e is not None else None
# (1, L, odim)
_, outs, d_outs, p_outs, e_outs = self._forward(
_, outs, d_outs, p_outs, e_outs, _ = self._forward(
xs,
ilens,
ds=ds,
......@@ -783,7 +807,7 @@ class FastSpeech2(nn.Layer):
is_inference=True)
else:
# (1, L, odim)
_, outs, d_outs, p_outs, e_outs = self._forward(
_, outs, d_outs, p_outs, e_outs, _ = self._forward(
xs,
ilens,
is_inference=True,
......@@ -791,6 +815,7 @@ class FastSpeech2(nn.Layer):
spk_emb=spk_emb,
spk_id=spk_id,
tone_id=tone_id)
return outs[0], d_outs[0], p_outs[0], e_outs[0]
def _integrate_with_spk_embed(self, hs, spk_emb):
......@@ -1058,6 +1083,7 @@ class FastSpeech2Loss(nn.Layer):
self.l1_criterion = nn.L1Loss(reduction=reduction)
self.mse_criterion = nn.MSELoss(reduction=reduction)
self.duration_criterion = DurationPredictorLoss(reduction=reduction)
self.ce_criterion = nn.CrossEntropyLoss()
def forward(
self,
......@@ -1072,7 +1098,10 @@ class FastSpeech2Loss(nn.Layer):
es: paddle.Tensor,
ilens: paddle.Tensor,
olens: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
spk_logits: paddle.Tensor=None,
spk_ids: paddle.Tensor=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
paddle.Tensor, ]:
"""Calculate forward propagation.
Args:
......@@ -1098,11 +1127,18 @@ class FastSpeech2Loss(nn.Layer):
Batch of the lengths of each input (B,).
olens(Tensor):
Batch of the lengths of each target (B,).
spk_logits(Option[Tensor]):
Batch of outputs after speaker classifier (B, Lmax, num_spk)
spk_ids(Option[Tensor]):
Batch of target spk_id (B,)
Returns:
"""
speaker_loss = 0.0
# apply mask to remove padded part
if self.use_masking:
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
......@@ -1124,6 +1160,16 @@ class FastSpeech2Loss(nn.Layer):
ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
es = es.masked_select(pitch_masks.broadcast_to(es.shape))
if spk_logits is not None and spk_ids is not None:
batch_size = spk_ids.shape[0]
spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1],
None)
spk_logits = paddle.reshape(spk_logits,
[-1, spk_logits.shape[-1]])
mask_index = spk_logits.abs().sum(axis=1) != 0
spk_ids = spk_ids[mask_index]
spk_logits = spk_logits[mask_index]
# calculate loss
l1_loss = self.l1_criterion(before_outs, ys)
if after_outs is not None:
......@@ -1132,6 +1178,9 @@ class FastSpeech2Loss(nn.Layer):
pitch_loss = self.mse_criterion(p_outs, ps)
energy_loss = self.mse_criterion(e_outs, es)
if spk_logits is not None and spk_ids is not None:
speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size
# make weighted mask and apply it
if self.use_weighted_masking:
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
......@@ -1161,4 +1210,4 @@ class FastSpeech2Loss(nn.Layer):
energy_loss = energy_loss.masked_select(
pitch_masks.broadcast_to(energy_loss.shape)).sum()
return l1_loss, duration_loss, pitch_loss, energy_loss
return l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss
......@@ -14,6 +14,7 @@
import logging
from pathlib import Path
from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
......@@ -23,6 +24,7 @@ from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
......@@ -31,24 +33,30 @@ logger.setLevel(logging.INFO)
class FastSpeech2Updater(StandardUpdater):
def __init__(self,
model: Layer,
optimizer: Optimizer,
dataloader: DataLoader,
init_state=None,
use_masking: bool=False,
use_weighted_masking: bool=False,
output_dir: Path=None):
def __init__(
self,
model: Layer,
optimizer: Optimizer,
dataloader: DataLoader,
init_state=None,
use_masking: bool=False,
spk_loss_scale: float=0.02,
use_weighted_masking: bool=False,
output_dir: Path=None,
enable_spk_cls: bool=False, ):
super().__init__(model, optimizer, dataloader, init_state=None)
self.criterion = FastSpeech2Loss(
use_masking=use_masking, use_weighted_masking=use_weighted_masking)
use_masking=use_masking,
use_weighted_masking=use_weighted_masking, )
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
self.spk_loss_scale = spk_loss_scale
self.enable_spk_cls = enable_spk_cls
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
......@@ -60,18 +68,33 @@ class FastSpeech2Updater(StandardUpdater):
if spk_emb is not None:
spk_id = None
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb)
l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(
if type(
self.model
) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier:
with self.model.no_sync():
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb)
else:
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb)
l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion(
after_outs=after_outs,
before_outs=before_outs,
d_outs=d_outs,
......@@ -82,9 +105,12 @@ class FastSpeech2Updater(StandardUpdater):
ps=batch["pitch"],
es=batch["energy"],
ilens=batch["text_lengths"],
olens=olens)
olens=olens,
spk_logits=spk_logits,
spk_ids=spk_id, )
loss = l1_loss + duration_loss + pitch_loss + energy_loss
scaled_speaker_loss = self.spk_loss_scale * speaker_loss
loss = l1_loss + duration_loss + pitch_loss + energy_loss + scaled_speaker_loss
optimizer = self.optimizer
optimizer.clear_grad()
......@@ -96,11 +122,18 @@ class FastSpeech2Updater(StandardUpdater):
report("train/duration_loss", float(duration_loss))
report("train/pitch_loss", float(pitch_loss))
report("train/energy_loss", float(energy_loss))
if self.enable_spk_cls:
report("train/speaker_loss", float(speaker_loss))
report("train/scaled_speaker_loss", float(scaled_speaker_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["energy_loss"] = float(energy_loss)
if self.enable_spk_cls:
losses_dict["speaker_loss"] = float(speaker_loss)
losses_dict["scaled_speaker_loss"] = float(scaled_speaker_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
......@@ -112,7 +145,9 @@ class FastSpeech2Evaluator(StandardEvaluator):
dataloader: DataLoader,
use_masking: bool=False,
use_weighted_masking: bool=False,
output_dir: Path=None):
spk_loss_scale: float=0.02,
output_dir: Path=None,
enable_spk_cls: bool=False):
super().__init__(model, dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
......@@ -120,6 +155,8 @@ class FastSpeech2Evaluator(StandardEvaluator):
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
self.spk_loss_scale = spk_loss_scale
self.enable_spk_cls = enable_spk_cls
self.criterion = FastSpeech2Loss(
use_masking=use_masking, use_weighted_masking=use_weighted_masking)
......@@ -133,18 +170,33 @@ class FastSpeech2Evaluator(StandardEvaluator):
if spk_emb is not None:
spk_id = None
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb)
l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(
if type(
self.model
) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier:
with self.model.no_sync():
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb)
else:
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb)
l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion(
after_outs=after_outs,
before_outs=before_outs,
d_outs=d_outs,
......@@ -155,19 +207,29 @@ class FastSpeech2Evaluator(StandardEvaluator):
ps=batch["pitch"],
es=batch["energy"],
ilens=batch["text_lengths"],
olens=olens, )
loss = l1_loss + duration_loss + pitch_loss + energy_loss
olens=olens,
spk_logits=spk_logits,
spk_ids=spk_id, )
scaled_speaker_loss = self.spk_loss_scale * speaker_loss
loss = l1_loss + duration_loss + pitch_loss + energy_loss + scaled_speaker_loss
report("eval/loss", float(loss))
report("eval/l1_loss", float(l1_loss))
report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss))
report("eval/energy_loss", float(energy_loss))
if self.enable_spk_cls:
report("train/speaker_loss", float(speaker_loss))
report("train/scaled_speaker_loss", float(scaled_speaker_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
if self.enable_spk_cls:
losses_dict["speaker_loss"] = float(speaker_loss)
losses_dict["scaled_speaker_loss"] = float(scaled_speaker_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from paddle.autograd import PyLayer
class GradientReversalFunction(PyLayer):
"""Gradient Reversal Layer from:
Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
Forward pass is the identity function. In the backward pass,
the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
"""
@staticmethod
def forward(ctx, x, lambda_=1):
"""Forward in networks
"""
ctx.save_for_backward(lambda_)
return x.clone()
@staticmethod
def backward(ctx, grads):
"""Backward in networks
"""
lambda_, = ctx.saved_tensor()
dx = -lambda_ * grads
return paddle.clip(dx, min=-0.5, max=0.5)
class GradientReversalLayer(nn.Layer):
"""Gradient Reversal Layer from:
Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
Forward pass is the identity function. In the backward pass,
the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
"""
def __init__(self, lambda_=1):
super(GradientReversalLayer, self).__init__()
self.lambda_ = lambda_
def forward(self, x):
"""Forward in networks
"""
return GradientReversalFunction.apply(x, self.lambda_)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from Cross-Lingual-Voice-Cloning(https://github.com/deterministic-algorithms-lab/Cross-Lingual-Voice-Cloning)
import paddle
from paddle import nn
from typeguard import check_argument_types
class SpeakerClassifier(nn.Layer):
def __init__(
self,
idim: int,
hidden_sc_dim: int,
spk_num: int, ):
assert check_argument_types()
super().__init__()
# store hyperparameters
self.idim = idim
self.hidden_sc_dim = hidden_sc_dim
self.spk_num = spk_num
self.model = nn.Sequential(
nn.Linear(self.idim, self.hidden_sc_dim),
nn.Linear(self.hidden_sc_dim, self.spk_num))
def parse_outputs(self, out, text_lengths):
mask = paddle.arange(out.shape[1]).expand(
[out.shape[0], out.shape[1]]) < text_lengths.unsqueeze(1)
out = paddle.transpose(out, perm=[2, 0, 1])
out = out * mask
out = paddle.transpose(out, perm=[1, 2, 0])
return out
def forward(self, encoder_outputs, text_lengths):
"""
encoder_outputs = [batch_size, seq_len, encoder_embedding_size]
text_lengths = [batch_size]
log probabilities of speaker classification = [batch_size, seq_len, spk_num]
"""
out = self.model(encoder_outputs)
out = self.parse_outputs(out, text_lengths)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册