losses.py 3.2 KB
Newer Older
L
LielinJiang 已提交
1 2
import numpy as np

L
fix nan  
LielinJiang 已提交
3 4
import paddle
import paddle.nn as nn
L
LielinJiang 已提交
5

6

L
fix nan  
LielinJiang 已提交
7
class GANLoss(nn.Layer):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """
    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective. It currently supports
                'vanilla', 'lsgan', and 'wgangp'.
            target_real_label (float) -- label value used for real images
            target_fake_label (float) -- label value used for fake images

        Raises:
            NotImplementedError: if ``gan_mode`` is not one of the supported modes.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.target_real_label = target_real_label
        self.target_fake_label = target_fake_label

        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            # BCEWithLogitsLoss applies the sigmoid internally, so the
            # discriminator is expected to output raw logits.
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'wgangp':
            # WGAN-GP uses the mean critic score directly in __call__;
            # no label-based criterion is needed.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        The label tensor is cached on the layer and reused across calls. It is
        rebuilt whenever the prediction's shape or dtype differs from the cached
        one. Otherwise a stale cache would mismatch inputs of a different shape,
        such as a smaller last batch or another discriminator scale.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images
                or fake images

        Returns:
            A label tensor filled with the ground truth label, with the same
            shape and dtype as ``prediction``.
        """
        if target_is_real:
            cache_name = 'target_real_tensor'
            label = self.target_real_label
        else:
            cache_name = 'target_fake_tensor'
            label = self.target_fake_label

        target_tensor = getattr(self, cache_name, None)
        if (target_tensor is None
                or target_tensor.shape != prediction.shape
                or target_tensor.dtype != prediction.dtype):
            target_tensor = paddle.full_like(prediction, fill_value=label)
            # Labels are constants: never propagate gradients into them.
            target_tensor.stop_gradient = True
            setattr(self, cache_name, target_tensor)
        return target_tensor

    def __call__(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) -- typically the prediction output from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images
                or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            # Critic objective: raise the score on real samples and lower it
            # on fake ones, i.e. minimize -mean(real) and +mean(fake).
            mean_score = prediction.mean()
            loss = -mean_score if target_is_real else mean_score
        return loss