"""
---
title: Generative Adversarial Networks (GAN)
summary: A simple PyTorch implementation/tutorial of Generative Adversarial Networks (GAN) loss functions.
---

# Generative Adversarial Networks (GAN)

This is an implementation of
[Generative Adversarial Networks](https://arxiv.org/abs/1406.2661).

The generator, $G(\pmb{z}; \theta_g)$, generates samples that match the
distribution of the data, while the discriminator, $D(\pmb{x}; \theta_d)$,
gives the probability that $\pmb{x}$ came from the data rather than from $G$.

We train $D$ and $G$ simultaneously in a two-player min-max game with value
function $V(G, D)$:

$$\min_G \max_D V(D, G) =
    \mathop{\mathbb{E}}_{\pmb{x} \sim p_{data}(\pmb{x})}
        \big[\log D(\pmb{x})\big] +
    \mathop{\mathbb{E}}_{\pmb{z} \sim p_{\pmb{z}}(\pmb{z})}
        \big[\log (1 - D(G(\pmb{z})))\big]
$$

$p_{data}(\pmb{x})$ is the probability distribution over the data,
while $p_{\pmb{z}}(\pmb{z})$ is the probability distribution of $\pmb{z}$, which is set to
Gaussian noise.

This file defines the loss functions. [Here](simple_mnist_experiment.html) is an MNIST example
with two multilayer perceptrons for the generator and the discriminator.
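
For orientation, below is a minimal sketch of how these losses can be plugged into a
training step. The `generator`, `discriminator`, `latent_dim`, `optimizer_d` and
`optimizer_g` are placeholders for the illustration and are not defined in this file;
see the linked MNIST experiment for a complete setup.

```python
# A minimal sketch of one training step (assumes `generator`, `discriminator`,
# `latent_dim`, `optimizer_d` and `optimizer_g` are defined elsewhere)
discriminator_loss = DiscriminatorLogitsLoss()
generator_loss = GeneratorLogitsLoss()

def train_step(x: torch.Tensor):
    # Sample latent vectors from the Gaussian prior
    z = torch.randn(x.shape[0], latent_dim, device=x.device)

    # Discriminator update: descending on the sum of the two BCE losses
    # ascends on log D(x) + log(1 - D(G(z)))
    optimizer_d.zero_grad()
    loss_true, loss_false = discriminator_loss(discriminator(x),
                                               discriminator(generator(z).detach()))
    (loss_true + loss_false).backward()
    optimizer_d.step()

    # Generator update
    optimizer_g.zero_grad()
    generator_loss(discriminator(generator(z))).backward()
    optimizer_g.step()
```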
"""

import torch
import torch.nn as nn
import torch.utils.data

from labml_helpers.module import Module


class DiscriminatorLogitsLoss(Module):
    """
    ## Discriminator Loss

    Discriminator should **ascend** on the gradient,

    $$\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^m \Bigg[
        \log D\Big(\pmb{x}^{(i)}\Big) +
        \log \Big(1 - D\Big(G\Big(\pmb{z}^{(i)}\Big)\Big)\Big)
    \Bigg]$$

    $m$ is the mini-batch size and $(i)$ is used to index samples in the mini-batch.
    $\pmb{x}$ are samples from $p_{data}$ and $\pmb{z}$ are samples from $p_z$.
    """

    def __init__(self, smoothing: float = 0.2):
        super().__init__()
        # We use PyTorch Binary Cross Entropy Loss, which is
        # $-\frac{1}{m} \sum \Big[y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})\Big]$,
        # where $y$ are the labels and $\hat{y}$ are the predictions.
        # *Note the negative sign*.
        # We use labels equal to $1$ for $\pmb{x}$ from $p_{data}$
        # and labels equal to $0$ for $\pmb{x}$ from $p_{G}$.
        # Then descending on the sum of these is the same as ascending on
        # the above gradient.
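        # Concretely, with $y = 1$ the loss term is $-\log D\Big(\pmb{x}^{(i)}\Big)$ and
        # with $y = 0$ it is $-\log \Big(1 - D\Big(G\Big(\pmb{z}^{(i)}\Big)\Big)\Big)$,
        # so minimizing their sum maximizes the discriminator objective above.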
        #
        # `BCEWithLogitsLoss` combines a sigmoid layer and binary cross entropy loss.
        self.loss_true = nn.BCEWithLogitsLoss()
        self.loss_false = nn.BCEWithLogitsLoss()

        # We use label smoothing because it seems to work better in some cases
        self.smoothing = smoothing

        # Labels are registered as buffers with persistence set to `False`,
        # so they are not saved in the module's `state_dict`.
        self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
        self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)

    def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
        """
        `logits_true` are logits from $D(\pmb{x}^{(i)})$ and
        `logits_false` are logits from $D(G(\pmb{z}^{(i)}))$
        """
        if len(logits_true) > len(self.labels_true):
            self.register_buffer("labels_true",
                                 _create_labels(len(logits_true), 1.0 - self.smoothing, 1.0, logits_true.device), False)
        if len(logits_false) > len(self.labels_false):
            self.register_buffer("labels_false",
                                 _create_labels(len(logits_false), 0.0, self.smoothing, logits_false.device), False)

        return (self.loss_true(logits_true, self.labels_true[:len(logits_true)]),
                self.loss_false(logits_false, self.labels_false[:len(logits_false)]))


class GeneratorLogitsLoss(Module):
    """
    ## Generator Loss

    Generator should **descend** on the gradient,

    $$\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^m \Bigg[
        \log \Big(1 - D\Big(G\Big(\pmb{z}^{(i)}\Big)\Big)\Big)
    \Bigg]$$
    """
    def __init__(self, smoothing: float = 0.2):
        super().__init__()
        self.loss_true = nn.BCEWithLogitsLoss()
        self.smoothing = smoothing
        # We use labels equal to $1$ for $\pmb{x}$ from $p_{G}$.
        # Then descending on this loss is the same as descending on
        # the above gradient.
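        # With labels of $1$ the loss is $-\log D\Big(G\Big(\pmb{z}^{(i)}\Big)\Big)$;
        # minimizing it is the non-saturating alternative suggested in the paper,
        # which gives stronger gradients when the discriminator rejects the generated samples.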
        self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)

    def __call__(self, logits: torch.Tensor):
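        """
        `logits` are logits from $D(G(\pmb{z}^{(i)}))$
        """
        # Recreate the label buffer on the same device if the batch is larger
        # than the cached labels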
        if len(logits) > len(self.fake_labels):
            self.register_buffer("fake_labels",
                                 _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)

        return self.loss_true(logits, self.fake_labels[:len(logits)])


def _create_labels(n: int, r1: float, r2: float, device: torch.device = None):
    """
    Create a `(n, 1)` tensor of labels sampled uniformly from `[r1, r2]` (used for label smoothing)
    """
    return torch.empty(n, 1, requires_grad=False, device=device).uniform_(r1, r2)