"""
This is an implementation of
[Generative Adversarial Networks](https://arxiv.org/abs/1406.2661).

The generator, $G(\pmb{z}; \theta_g)$, generates samples that match the
distribution of the data, while the discriminator, $D(\pmb{x}; \theta_d)$,
gives the probability that $\pmb{x}$ came from the data rather than from $G$.

We train $D$ and $G$ simultaneously in a two-player min-max game with value
function $V(G, D)$.

$$\min_G \max_D V(D, G) =
    \mathop{\mathbb{E}}_{\pmb{x} \sim p_{data}(\pmb{x})}
        \big[\log D(\pmb{x})\big] +
    \mathop{\mathbb{E}}_{\pmb{z} \sim p_{\pmb{z}}(\pmb{z})}
        \big[\log \big(1 - D(G(\pmb{z}))\big)\big]
$$

$p_{data}(\pmb{x})$ is the probability distribution over the data,
while $p_{\pmb{z}}(\pmb{z})$ is the probability distribution of $\pmb{z}$,
which is set to Gaussian noise.

This file defines the loss functions. [Here](gan_mnist.html) is an MNIST example
with two multilayer perceptrons, for the generator and the discriminator.
"""

import torch
import torch.nn as nn
import torch.utils.data

from labml_helpers.module import Module


class DiscriminatorLogitsLoss(Module):
    """
    ## Discriminator Loss

    The discriminator should **ascend** on the gradient,

    $$\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^m \Bigg[
        \log D\Big(\pmb{x}^{(i)}\Big) +
        \log \Big(1 - D\Big(G\Big(\pmb{z}^{(i)}\Big)\Big)\Big)
    \Bigg]$$

    $m$ is the mini-batch size and $(i)$ is used to index samples in the mini-batch.
    $\pmb{x}$ are samples from $p_{data}$ and $\pmb{z}$ are samples from $p_z$.
    """

    def __init__(self, smoothing: float = 0.2):
        super().__init__()
        # We use PyTorch Binary Cross Entropy Loss, which is
        # $-\sum\Big[y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})\Big]$,
        # where $y$ are the labels and $\hat{y}$ are the predictions.
        # *Note the negative sign*.
        # We use labels equal to $1$ for $\pmb{x}$ from $p_{data}$
        # and labels equal to $0$ for $\pmb{x}$ from $p_{G}$.
        # Then descending on the sum of these is the same as ascending on
        # the above gradient.
        #
        # `BCEWithLogitsLoss` combines a sigmoid activation and binary cross entropy loss.
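        #
        # For example, with a label $y = 1$ and a predicted probability $\hat{y} = 0.9$
        # the per-sample loss is $-\log 0.9 \approx 0.105$,
        # while with $y = 0$ and $\hat{y} = 0.9$ it is $-\log 0.1 \approx 2.303$,
        # so confident mistakes are penalized heavily.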
        self.loss_true = nn.BCEWithLogitsLoss()
        self.loss_false = nn.BCEWithLogitsLoss()

        # We use label smoothing because it seems to work better in some cases
        self.smoothing = smoothing

        # Labels are registered as buffers, with persistence set to `False`
        # so that they are not saved in the model's `state_dict`.
        self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
        self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)

    def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
        """
        `logits_true` are logits from $D(\pmb{x}^{(i)})$ and
        `logits_false` are logits from $D(G(\pmb{z}^{(i)}))$
        """
        if len(logits_true) > len(self.labels_true):
            self.register_buffer("labels_true",
                                 _create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
        if len(logits_false) > len(self.labels_false):
            self.register_buffer("labels_false",
                                 _create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)
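
        # Return the two components separately; the total discriminator loss is their sum.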
        return (self.loss_true(logits_true, self.labels_true[:len(logits_true)]),
                self.loss_false(logits_false, self.labels_false[:len(logits_false)]))


class GeneratorLogitsLoss(Module):
    """
    ## Generator Loss

    The generator should **descend** on the gradient,

    $$\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^m \Bigg[
        \log \Big(1 - D\Big(G\Big(\pmb{z}^{(i)}\Big)\Big)\Big)
    \Bigg]$$
    """
    def __init__(self, smoothing: float = 0.2):
        super().__init__()
        self.loss_true = nn.BCEWithLogitsLoss()
        self.smoothing = smoothing
        # We use labels equal to $1$ for $\pmb{x}$ from $p_{G}$,
        # so (up to label smoothing) this loss is $-\log D(G(\pmb{z}))$ rather than
        # $\log \big(1 - D(G(\pmb{z}))\big)$.
        # Descending on it pushes $D(G(\pmb{z}))$ in the same direction as descending on
        # the gradient above, but gives stronger gradients early in training.
        self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)

    def __call__(self, logits: torch.Tensor):
        if len(logits) > len(self.fake_labels):
            self.register_buffer("fake_labels",
                                 _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)

        return self.loss_true(logits, self.fake_labels[:len(logits)])


def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
    """
    Create smoothed labels by sampling uniformly from $[r_1, r_2]$
    """
    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)
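

def _usage_example():
    """
    A minimal sketch of how these losses can be combined in a training step.
    `logits_real` and `logits_fake` are stand-ins for $D(\pmb{x})$ and
    $D(G(\pmb{z}))$; in practice they come from the discriminator and generator
    networks, as in the [MNIST example](gan_mnist.html).
    """
    discriminator_loss = DiscriminatorLogitsLoss()
    generator_loss = GeneratorLogitsLoss()

    # Stand-in logits for $D(\pmb{x})$ and $D(G(\pmb{z}))$
    logits_real = torch.randn(64, 1)
    logits_fake = torch.randn(64, 1)

    # The discriminator descends on the sum of the losses on real and generated
    # samples; in a training loop this is followed by `backward()` and an
    # optimizer step on the discriminator parameters only.
    loss_true, loss_false = discriminator_loss(logits_real, logits_fake)
    d_loss = loss_true + loss_false

    # The generator only needs the logits on generated samples; in a training
    # loop these are recomputed so that gradients flow through $G$, followed by
    # an optimizer step on the generator parameters.
    g_loss = generator_loss(logits_fake)

    return d_loss, g_loss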