# losses.py (committed by LielinJiang)
import paddle
import paddle.nn as nn
import numpy as np

from ..modules.nn import BCEWithLogitsLoss

class GANLoss(paddle.fluid.dygraph.Layer):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective. It currently supports
                vanilla, lsgan, and wgangp.
            target_real_label (float) -- label value used for a real image
            target_fake_label (float) -- label value used for a fake image

        Raises:
            NotImplementedError -- if ``gan_mode`` is not one of the supported modes.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.target_real_label = target_real_label
        self.target_fake_label = target_fake_label

        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            # WGAN-GP computes its loss directly from the prediction mean in
            # __call__, so no criterion object is needed.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        The label tensor is cached per label type (real / fake) and reused
        across calls. It is rebuilt whenever the shape of ``prediction`` no
        longer matches the cached tensor, for example a smaller final batch.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with ground truth label, and with the size of the input
        """
        if target_is_real:
            cache_name, label_value = 'target_real_tensor', self.target_real_label
        else:
            cache_name, label_value = 'target_fake_tensor', self.target_fake_label

        target_tensor = getattr(self, cache_name, None)
        # Rebuild on a shape change: a label tensor cached from an earlier call
        # with a different prediction shape would otherwise be reused as-is.
        if target_tensor is None or list(target_tensor.shape) != list(prediction.shape):
            target_tensor = paddle.fill_constant(
                shape=paddle.shape(prediction),
                value=label_value,
                dtype='float32')
            setattr(self, cache_name, target_tensor)
        return target_tensor

    def __call__(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) -- typically the prediction output from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.

        Raises:
            NotImplementedError -- if ``self.gan_mode`` is not a supported mode.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            # Negative mean for real targets and plain mean for fake targets.
            # Minimizing this pushes real scores up and fake scores down.
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        else:
            # Without this branch `loss` would be unbound if gan_mode were
            # changed to an unsupported value after construction.
            raise NotImplementedError('gan mode %s not implemented' % self.gan_mode)
        return loss