# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
import numpy as np
from numba import jit

from paddle import fluid
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg


def masked_mean(inputs, mask):
    """
    Args:
        inputs (Variable): shape(B, T, C), dtype float32, the input.
        mask (Variable): shape(B, T), dtype float32, a mask. 
    Returns:
        loss (Variable): shape(1, ), dtype float32, masked mean.
    """
    channels = inputs.shape[-1]
    masked_inputs = F.elementwise_mul(inputs, mask, axis=0)
    loss = F.reduce_sum(masked_inputs) / (channels * F.reduce_sum(mask))
    return loss
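
# A rough usage sketch of masked_mean (added for illustration; assumes dygraph
# mode on CPU). With an all-ones input of shape (2, 3, 4) and three valid
# (batch, time) positions, the mean over the valid entries is 1.0:
#
#     with dg.guard(fluid.CPUPlace()):
#         x = dg.to_variable(np.ones((2, 3, 4), dtype=np.float32))
#         m = dg.to_variable(
#             np.array([[1, 1, 0], [1, 0, 0]], dtype=np.float32))
#         masked_mean(x, m).numpy()  # -> [1.0]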


@jit(nopython=True)
def guided_attention(N, max_N, T, max_T, g):
    """Generate an diagonal attention guide.
    
    Args:
        N (int): valid length of encoder.
        max_N (int): max length of encoder.
        T (int): valid length of decoder.
        max_T (int): max length of decoder.
        g (float): sigma to adjust the degree of diagonal guide.

    Returns:
        np.ndarray: shape(max_N, max_T), dtype float32, the diagonal guide.
    """
    W = np.zeros((max_N, max_T), dtype=np.float32)
    for n in range(N):
        for t in range(T):
            W[n, t] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
    return W
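
# The guide is 0 along the rescaled diagonal (n / N == t / T) and approaches 1
# far away from it, so multiplying it with an attention matrix penalizes
# off-diagonal attention mass. Illustrative check (pure numpy, no paddle):
#
#     W = guided_attention(5, 8, 10, 12, 0.2)  # shape (8, 12)
#     W[0, 0]  # 0.0, exactly on the diagonal
#     W[4, 0]  # ~0.9997, far off the diagonal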


def guided_attentions(encoder_lengths, decoder_lengths, max_decoder_len,
                      g=0.2):
    """Generate a diagonal attention guide for a batch.

    Args:
        encoder_lengths (np.ndarray): shape(B, ), dtype: int64, encoder valid lengths.
        decoder_lengths (np.ndarray): shape(B, ), dtype: int64, decoder valid lengths.
        max_decoder_len (int): max length of decoder.
        g (float, optional): sigma to adjust the degree of diagonal guide. Defaults to 0.2.

    Returns:
        np.ndarray: shape(B, max_T, max_N), dtype float32, the diagonal guide (max_N: max encoder length, max_T: max decoder length).
    """
    B = len(encoder_lengths)
    max_input_len = encoder_lengths.max()
    W = np.zeros((B, max_decoder_len, max_input_len), dtype=np.float32)
    for b in range(B):
        W[b] = guided_attention(encoder_lengths[b], max_input_len,
                                decoder_lengths[b], max_decoder_len, g).T
    return W
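
# Batch-level sketch (illustrative lengths): two utterances with different
# valid lengths share one padded guide of layout (B, max_T_dec, max_T_enc):
#
#     enc_lens = np.array([5, 8], dtype=np.int64)
#     dec_lens = np.array([10, 12], dtype=np.int64)
#     guided_attentions(enc_lens, dec_lens, 12).shape  # (2, 12, 8)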


class TTSLoss(object):
    def __init__(self,
                 masked_weight=0.0,
                 priority_bin=None,
                 priority_weight=0.0,
                 binary_divergence_weight=0.0,
                 guided_attention_sigma=0.2,
                 downsample_factor=4,
                 r=1):
        """Compute loss for Deep Voice 3 model.

        Args:
            masked_weight (float, optional): the weight of masked loss. Defaults to 0.0.
            priority_bin (int, optional): the number of frequency bands of the linear spectrogram loss to be prioritized. Defaults to None.
            priority_weight (float, optional): weight for the prioritized frequency bands. Defaults to 0.0.
            binary_divergence_weight (float, optional): weight for binary cross entropy (used for spectrogram loss). Defaults to 0.0.
            guided_attention_sigma (float, optional): `sigma` for attention guide. Defaults to 0.2.
            downsample_factor (int, optional): the downsample factor for mel spectrogram. Defaults to 4.
            r (int, optional): frames per decoder step. Defaults to 1.
        """
        self.masked_weight = masked_weight
        self.priority_bin = priority_bin  # only used for lin-spec loss
        self.priority_weight = priority_weight  # only used for lin-spec loss
        self.binary_divergence_weight = binary_divergence_weight
        self.guided_attention_sigma = guided_attention_sigma

        self.time_shift = r
        self.r = r
        self.downsample_factor = downsample_factor

    def l1_loss(self, prediction, target, mask, priority_bin=None):
        """L1 loss for spectrogram.

        Args:
            prediction (Variable): shape(B, T, C), dtype float32, predicted spectrogram.
            target (Variable): shape(B, T, C), dtype float32, target spectrogram.
            mask (Variable): shape(B, T), dtype float32, the mask.
            priority_bin (int, optional): frequency bands for linear spectrogram loss to be prioritized. Defaults to None.

        Returns:
            Variable: shape(1,), dtype float32, L1 loss (with the mask and, if given, the priority bin applied).
        """
        abs_diff = F.abs(prediction - target)

        # basic mask-weighted l1 loss
        w = self.masked_weight
        if w > 0 and mask is not None:
            base_l1_loss = w * masked_mean(abs_diff, mask) \
                         + (1 - w) * F.reduce_mean(abs_diff)
        else:
            base_l1_loss = F.reduce_mean(abs_diff)

        if self.priority_weight > 0 and priority_bin is not None:
            # mask-weighted priority channels' l1-loss
            priority_abs_diff = abs_diff[:, :, :priority_bin]
            if w > 0 and mask is not None:
                priority_loss = w * masked_mean(priority_abs_diff, mask) \
                              + (1 - w) * F.reduce_mean(priority_abs_diff)
            else:
                priority_loss = F.reduce_mean(priority_abs_diff)

            # priority weighted sum
            p = self.priority_weight
            loss = p * priority_loss + (1 - p) * base_l1_loss
        else:
            loss = base_l1_loss
        return loss
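
    # In effect, with masked_weight w and priority_weight p, l1_loss returns
    #     p * l1(bands[:priority_bin]) + (1 - p) * l1(all bands),
    # where each l1 term is itself w * masked_mean + (1 - w) * plain mean.
    # (A summary of the code above, not an extra computation.)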

    def binary_divergence(self, prediction, target, mask):
        """Binary cross entropy loss for spectrogram. All the values in the spectrogram are treated as logits in a logistic regression.

        Args:
            prediction (Variable): shape(B, T, C), dtype float32, predicted spectrogram.
            target (Variable): shape(B, T, C), dtype float32, target spectrogram.
            mask (Variable): shape(B, T), dtype float32, the mask.

        Returns:
            Variable: shape(1,), dtype float32, binary cross entropy loss.
        """
        flattened_prediction = F.reshape(prediction, [-1, 1])
        flattened_target = F.reshape(target, [-1, 1])
        flattened_loss = F.log_loss(
            flattened_prediction, flattened_target, epsilon=1e-8)
        bin_div = F.reshape(flattened_loss, prediction.shape)

        w = self.masked_weight
        if w > 0 and mask is not None:
            loss = w * masked_mean(bin_div, mask) \
                 + (1 - w) * F.reduce_mean(bin_div)
        else:
            loss = F.reduce_mean(bin_div)
        return loss
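
    # Note: F.log_loss expects probabilities in (0, 1), so this term assumes
    # spectrogram values are normalized to [0, 1]. Rough sketch (illustrative
    # values, inside a dygraph guard):
    #
    #     pred = dg.to_variable(np.full((1, 2, 3), 0.5, dtype=np.float32))
    #     tgt = dg.to_variable(np.full((1, 2, 3), 0.5, dtype=np.float32))
    #     TTSLoss().binary_divergence(pred, tgt, None).numpy()  # ~log(2)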

    @staticmethod
    def done_loss(done_hat, done):
        """Compute done loss

        Args:
            done_hat (Variable): shape(B, T), dtype float32, predicted done probability (the probability that the final frame has been generated).
            done (Variable): shape(B, T), dtype float32, ground truth done probability (the probability that the final frame has been generated).

        Returns:
            Variable: shape(1, ), dtype float32, done loss.
        """
        flat_done_hat = F.reshape(done_hat, [-1, 1])
        flat_done = F.reshape(done, [-1, 1])
        loss = F.log_loss(flat_done_hat, flat_done, epsilon=1e-8)
        loss = F.reduce_mean(loss)
        return loss
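
    # Sketch: both arguments are per-step probabilities, so a near-perfect
    # prediction drives the loss toward zero (illustrative values, inside a
    # dygraph guard):
    #
    #     hat = dg.to_variable(np.full((1, 4), 0.99, dtype=np.float32))
    #     ref = dg.to_variable(np.ones((1, 4), dtype=np.float32))
    #     TTSLoss.done_loss(hat, ref).numpy()  # ~ -log(0.99) ~= 0.01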

    def attention_loss(self, predicted_attention, input_lengths,
                       target_lengths):
        """
194
        Given valid encoder_lengths and decoder_lengths, compute a diagonal guide, and compute loss from the predicted attention and the guide.
        
        Args:
            predicted_attention (Variable): shape(N, B, T_dec, T_enc), dtype float32, the alignment tensor, where N is the number of attention layers, B the batch size, T_dec the number of decoder time steps and T_enc the number of encoder time steps.
            input_lengths (numpy.ndarray): shape(B,), dtype int64, valid lengths (time steps) of encoder outputs.
            target_lengths (numpy.ndarray): shape(B,), dtype int64, valid lengths (time steps) of decoder outputs.
        
        Returns:
            loss (Variable): shape(1, ), dtype float32, attention loss.
        """
        n_attention, batch_size, max_target_len, max_input_len = (
            predicted_attention.shape)
        soft_mask = guided_attentions(input_lengths, target_lengths,
                                      max_target_len,
                                      self.guided_attention_sigma)
        soft_mask_ = dg.to_variable(soft_mask)
        loss = fluid.layers.reduce_mean(predicted_attention * soft_mask_)
        return loss
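
    # Sketch (illustrative shapes, inside a dygraph guard): attention from a
    # single layer, batch of 2, 12 decoder and 8 encoder steps; the lengths
    # are plain numpy arrays, not Variables:
    #
    #     attn = dg.to_variable(
    #         np.random.uniform(size=(1, 2, 12, 8)).astype(np.float32))
    #     enc_lens = np.array([8, 5], dtype=np.int64)
    #     dec_lens = np.array([12, 10], dtype=np.int64)
    #     TTSLoss().attention_loss(attn, enc_lens, dec_lens)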

    def __call__(self,
                 mel_hyp,
                 lin_hyp,
                 done_hyp,
                 attn_hyp,
                 mel_ref,
                 lin_ref,
                 done_ref,
                 input_lengths,
                 n_frames,
                 compute_lin_loss=True,
                 compute_mel_loss=True,
                 compute_done_loss=True,
                 compute_attn_loss=True):
        """Total loss

        Args:
            mel_hyp (Variable): shape(B, T, C_mel), dtype float32, predicted mel spectrogram.
            lin_hyp (Variable): shape(B, T, C_lin), dtype float32, predicted linear spectrogram.
            done_hyp (Variable): shape(B, T), dtype float32, predicted done probability.
            attn_hyp (Variable): shape(N, B, T_dec, T_enc), dtype float32, predicted attention.
            mel_ref (Variable): shape(B, T, C_mel), dtype float32, ground truth mel spectrogram.
            lin_ref (Variable): shape(B, T, C_lin), dtype float32, ground truth linear spectrogram.
            done_ref (Variable): shape(B, T), dtype float32, ground truth done flag.
            input_lengths (Variable): shape(B, ), dtype int, encoder valid lengths.
            n_frames (Variable): shape(B, ), dtype int, valid number of frames of the linear spectrograms (mel and decoder lengths are derived from it).
            compute_lin_loss (bool, optional): whether to compute linear loss. Defaults to True.
            compute_mel_loss (bool, optional): whether to compute mel loss. Defaults to True.
            compute_done_loss (bool, optional): whether to compute done loss. Defaults to True.
            compute_attn_loss (bool, optional): whether to compute attention loss. Defaults to True.

        Returns:
            Dict(str, Variable): details of loss.
        """
        total_loss = 0.

        # derive the mel and decoder lengths from n_frames, the valid lengths
        # of the linear spectrograms
        max_frames = lin_hyp.shape[1]
        max_mel_steps = max_frames // self.downsample_factor
        max_decoder_steps = max_mel_steps // self.r

        decoder_mask = F.sequence_mask(
            n_frames // self.downsample_factor // self.r,
            max_decoder_steps,
            dtype="float32")
        mel_mask = F.sequence_mask(
            n_frames // self.downsample_factor, max_mel_steps, dtype="float32")
        lin_mask = F.sequence_mask(n_frames, max_frames, dtype="float32")

        if compute_lin_loss:
            lin_hyp = lin_hyp[:, :-self.time_shift, :]
            lin_ref = lin_ref[:, self.time_shift:, :]
            lin_mask = lin_mask[:, self.time_shift:]
            lin_l1_loss = self.l1_loss(
                lin_hyp, lin_ref, lin_mask, priority_bin=self.priority_bin)
            lin_bce_loss = self.binary_divergence(lin_hyp, lin_ref, lin_mask)
            lin_loss = self.binary_divergence_weight * lin_bce_loss \
                     + (1 - self.binary_divergence_weight) * lin_l1_loss
            total_loss += lin_loss

        if compute_mel_loss:
            mel_hyp = mel_hyp[:, :-self.time_shift, :]
            mel_ref = mel_ref[:, self.time_shift:, :]
            mel_mask = mel_mask[:, self.time_shift:]
            mel_l1_loss = self.l1_loss(mel_hyp, mel_ref, mel_mask)
            mel_bce_loss = self.binary_divergence(mel_hyp, mel_ref, mel_mask)
            mel_loss = self.binary_divergence_weight * mel_bce_loss \
                     + (1 - self.binary_divergence_weight) * mel_l1_loss
            total_loss += mel_loss

        if compute_attn_loss:
            attn_loss = self.attention_loss(attn_hyp,
                                            input_lengths.numpy(),
                                            n_frames.numpy() //
                                            (self.downsample_factor * self.r))
            total_loss += attn_loss

        if compute_done_loss:
            done_loss = self.done_loss(done_hyp, done_ref)
            total_loss += done_loss

        result = {
            "loss": total_loss,
            "mel/mel_loss": mel_loss if compute_mel_loss else None,
            "mel/l1_loss": mel_l1_loss if compute_mel_loss else None,
            "mel/bce_loss": mel_bce_loss if compute_mel_loss else None,
            "lin/lin_loss": lin_loss if compute_lin_loss else None,
            "lin/l1_loss": lin_l1_loss if compute_lin_loss else None,
            "lin/bce_loss": lin_bce_loss if compute_lin_loss else None,
            "done": done_loss if compute_done_loss else None,
            "attn": attn_loss if compute_attn_loss else None,
        }
        return result
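

if __name__ == "__main__":
    # A minimal smoke test with random inputs (added for illustration, not
    # part of the original training code). Shapes are illustrative: B=2,
    # 32 linear frames, downsample_factor=4 and r=1, hence 8 mel frames and
    # 8 decoder steps; values lie in (0, 1) as the binary divergence and done
    # losses require.
    with dg.guard(fluid.CPUPlace()):
        B, T_lin, C_lin, C_mel, T_enc = 2, 32, 9, 8, 5
        T_mel = T_lin // 4

        def rand(*shape):
            return dg.to_variable(
                np.random.uniform(0.01, 0.99, shape).astype(np.float32))

        criterion = TTSLoss(
            masked_weight=0.5, priority_bin=4, priority_weight=0.5)
        losses = criterion(
            rand(B, T_mel, C_mel),  # mel_hyp
            rand(B, T_lin, C_lin),  # lin_hyp
            rand(B, T_mel),  # done_hyp
            rand(1, B, T_mel, T_enc),  # attn_hyp (one attention layer)
            rand(B, T_mel, C_mel),  # mel_ref
            rand(B, T_lin, C_lin),  # lin_ref
            rand(B, T_mel),  # done_ref
            dg.to_variable(np.array([5, 3], dtype=np.int64)),  # input_lengths
            dg.to_variable(np.array([32, 24], dtype=np.int64)))  # n_frames
        print({key: value.numpy() if value is not None else None
               for key, value in losses.items()})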