import paddle
import paddle.nn as nn
import numpy as np

from ..modules.nn import BCEWithLogitsLoss


class GANLoss(paddle.fluid.dygraph.Layer):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label
    tensor that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective. It currently
                supports 'vanilla', 'lsgan', and 'wgangp'.
            target_real_label (float) -- label value for a real image
            target_fake_label (float) -- label value for a fake image

        Raises:
            NotImplementedError: if ``gan_mode`` is not one of the supported
                modes.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. Vanilla GANs will handle it with
        BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.target_real_label = target_real_label
        self.target_fake_label = target_fake_label

        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            # WGAN-GP works directly on the mean of the prediction, so no
            # criterion object is needed.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def _get_cached_label(self, attr_name, value, prediction):
        """Return the label tensor cached under ``attr_name``, rebuilding it
        when it is missing or its shape no longer matches ``prediction``."""
        cached = getattr(self, attr_name, None)
        # A label cached for an earlier prediction is only reusable if the
        # shapes agree; otherwise (e.g. a smaller final batch) it must be
        # recreated, or the loss would be computed against a stale-sized label.
        if cached is None or list(cached.shape) != list(prediction.shape):
            cached = paddle.fill_constant(shape=paddle.shape(prediction),
                                          value=value,
                                          dtype='float32')
            setattr(self, attr_name, cached)
        return cached

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) -- typically the prediction from a
                discriminator
            target_is_real (bool) -- if the ground truth label is for real
                images or fake images

        Returns:
            A float32 label tensor filled with the ground truth label, and
            with the size of the input. The tensor is cached on the instance
            (``target_real_tensor`` / ``target_fake_tensor``) and reused while
            the prediction shape stays the same.
        """
        if target_is_real:
            return self._get_cached_label('target_real_tensor',
                                          self.target_real_label, prediction)
        return self._get_cached_label('target_fake_tensor',
                                      self.target_fake_label, prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) -- typically the prediction output from a
                discriminator
            target_is_real (bool) -- if the ground truth label is for real
                images or fake images

        Returns:
            the calculated loss.

        Raises:
            NotImplementedError: if ``self.gan_mode`` was changed to an
                unsupported value after construction.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            # Real samples are scored with the negated mean, fake samples
            # with the plain mean.
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        else:
            # Only reachable if gan_mode was mutated after __init__ validated
            # it; fail with a clear message instead of an unbound local.
            raise NotImplementedError('gan mode %s not implemented' %
                                      self.gan_mode)
        return loss