losses.py 4.4 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

L
LielinJiang 已提交
15 16
import numpy as np

L
fix nan  
LielinJiang 已提交
17 18
import paddle
import paddle.nn as nn
19
import paddle.nn.functional as F
L
LielinJiang 已提交
20

21

L
fix nan  
LielinJiang 已提交
22
class GANLoss(nn.Layer):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """
    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective. Supported values are
                'vanilla', 'lsgan', 'wgan', 'wgangp', 'hinge' and 'logistic'.
            target_real_label (float) -- label value for a real image
                (only used by 'vanilla' and 'lsgan').
            target_fake_label (float) -- label value for a fake image
                (only used by 'vanilla' and 'lsgan').

        Raises:
            NotImplementedError -- if gan_mode is not a supported value.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.target_real_label = target_real_label
        self.target_fake_label = target_fake_label

        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgan', 'wgangp', 'hinge', 'logistic']:
            # These objectives are computed directly from the raw prediction
            # in __call__, so no label-based criterion is needed.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create a label tensor with the same shape and dtype as the input.

        The tensor is cached per label kind (real / fake) and reused while the
        prediction's shape and dtype stay the same. It is rebuilt when either
        changes, e.g. for a smaller last batch or a discriminator whose output
        size varies between calls.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real
                images (True) or fake images (False)

        Returns:
            A label tensor filled with the ground truth label, with the same
            shape and dtype as ``prediction``.
        """
        if target_is_real:
            cache_name, label = 'target_real_tensor', self.target_real_label
        else:
            cache_name, label = 'target_fake_tensor', self.target_fake_label

        target_tensor = getattr(self, cache_name, None)
        # The cache must be invalidated when the prediction changes shape or
        # dtype; otherwise the loss criterion receives a mismatched target.
        if (target_tensor is None or target_tensor.shape != prediction.shape
                or target_tensor.dtype != prediction.dtype):
            target_tensor = paddle.full(
                shape=paddle.shape(prediction),
                fill_value=label,
                dtype=prediction.dtype)
            # Labels are constants; never propagate gradients into them.
            target_tensor.stop_gradient = True
            setattr(self, cache_name, target_tensor)
        return target_tensor

    def __call__(self, prediction, target_is_real, is_updating_D=None):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real
                images (True) or fake images (False)
            is_updating_D (bool) -- whether we are in the discriminator update
                step. Only used by the 'hinge' objective; a falsy value
                (including the default None) selects the generator form.

        Returns:
            the calculated (scalar) loss.

        Raises:
            NotImplementedError -- if self.gan_mode is not a supported value.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif 'wgan' in self.gan_mode:
            # Covers both 'wgan' and 'wgangp'.
            loss = -prediction.mean() if target_is_real else prediction.mean()
        elif self.gan_mode == 'hinge':
            if target_is_real:
                loss = F.relu(1 - prediction) if is_updating_D else -prediction
            else:
                loss = F.relu(1 + prediction) if is_updating_D else prediction
            loss = loss.mean()
        elif self.gan_mode == 'logistic':
            if target_is_real:
                loss = F.softplus(-prediction).mean()
            else:
                loss = F.softplus(prediction).mean()
        else:
            # Guards against gan_mode being changed after construction, which
            # would otherwise surface as an UnboundLocalError on `loss`.
            raise NotImplementedError('gan mode %s not implemented' % self.gan_mode)
        return loss