losses.py 3.2 KB
Newer Older
L
LielinJiang 已提交
1 2
import numpy as np

L
fix nan  
LielinJiang 已提交
3 4
import paddle
import paddle.nn as nn
L
LielinJiang 已提交
5

6

L
fix nan  
LielinJiang 已提交
7
class GANLoss(nn.Layer):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective. It currently supports
                vanilla, lsgan, and wgangp.
            target_real_label (float) -- label value used for real images
            target_fake_label (float) -- label value used for fake images

        Raises:
            NotImplementedError -- if gan_mode is not one of the supported modes.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.target_real_label = target_real_label
        self.target_fake_label = target_fake_label

        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            # The criterion takes raw logits (see the Note above), so the
            # discriminator must not end with its own sigmoid.
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            # wgangp computes the loss directly from the prediction mean in
            # __call__, so no criterion object is needed.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create a label tensor with the same shape as the input.

        The tensor is cached separately for the real and fake labels. It is
        rebuilt whenever the cached tensor's shape differs from the shape of
        ``prediction``, for example on a smaller final batch.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real
                images (True) or fake images (False)

        Returns:
            A float32 label tensor filled with the ground truth label, with the
            same shape as ``prediction``.
        """
        if target_is_real:
            cache_name = 'target_real_tensor'
            label_value = self.target_real_label
        else:
            cache_name = 'target_fake_tensor'
            label_value = self.target_fake_label

        target_tensor = getattr(self, cache_name, None)
        # Reusing a cached tensor of a different shape would feed the criterion
        # a target that does not match the prediction, so rebuild it instead.
        if (target_tensor is None
                or list(target_tensor.shape) != list(prediction.shape)):
            target_tensor = paddle.fill_constant(
                shape=paddle.shape(prediction),
                value=label_value,
                dtype='float32')
            setattr(self, cache_name, target_tensor)
        return target_tensor

    def __call__(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real
                images (True) or fake images (False)

        Returns:
            the calculated loss.

        Raises:
            NotImplementedError -- if self.gan_mode is not a supported mode.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            # Minimizing -mean for real and +mean for fake drives real scores
            # up and fake scores down.
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        else:
            # Guard against a gan_mode changed after construction; otherwise
            # `loss` would be unbound here.
            raise NotImplementedError(
                'gan mode %s not implemented' % self.gan_mode)
        return loss