"""
---
title: GPT
summary: >
  Implementation/tutorial of GPT model and training code.
---

# GPT

This is a tutorial/implementation of the
[OpenAI GPT architecture](https://openai.com/blog/better-language-models/).
We got a number of implementation details from
[minGPT](https://github.com/karpathy/minGPT)
by [@karpathy](https://twitter.com/karpathy).
This implementation also uses the character-level Tiny Shakespeare dataset.

The GPT model is essentially a standard transformer with a few tweaks.
GPT-2, and especially GPT-3, models are quite large and won't fit on a
single GPU, so they need model parallelism.
This implementation doesn't even use data parallelism and is intended to be
more of a tutorial.

The main differences from a standard autoregressive transformer
are the parameter initialization, weight decay, and learning rate schedule.
For the transformer we reuse the
[existing labml/nn transformer implementation](https://lab-ml.com/labml_nn/transformers/).
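
The learning rate schedule (`AdamWarmupCosineDecay` below) increases the learning
rate linearly during warmup and then decays it along a cosine curve.
As a rough sketch (not the exact implementation), such a schedule looks like
$\eta_t = \eta_{max} \min \Big( \frac{t}{w},\ \frac{1}{2} \big( 1 + \cos \frac{\pi (t - w)}{T - w} \big) \Big)$,
where $w$ is the number of warmup steps and $T$ is the total number of optimizer steps.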

Here's a notebook for training a GPT model on the Tiny Shakespeare dataset.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/gpt/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=0324c6d0562111eba65d0242ac1c0002)
"""

import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.optimizers.configs import OptimizerConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask


class GPT(Module):
    """
    ## GPT model

    This consists of a token embedding layer, transformer encoder, and
    a final linear layer that gives token logits.
    """

    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
        """
        * `encoder` is the transformer [Encoder](../models.html#Encoder)
        * `src_embed` is the token
        [embedding module (with positional encodings)](../models.html#EmbeddingsWithLearnedPositionalEncoding)
        * `generator` is the [final fully connected layer](../models.html#Generator) that gives the logits.
        """
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

        # The mask will be initialized on the first call
        self.mask = None

    def __call__(self, x: torch.Tensor):
        # Create subsequent mask if mask is not initialized
        # or if the size of the mask is different
        if self.mask is None or self.mask.size(0) != len(x):
            # Subsequent (causal) mask that prevents tokens from attending to future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
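            # The mask is causal (lower-triangular): position $i$ can only attend
            # to positions $j \le i$, so a token never sees future tokens.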
        # Get the token embeddings with positional encodings
        x = self.src_embed(x)
        # Transformer encoder
        x = self.encoder(x, self.mask)
        # Get logits
        x = self.generator(x)

        # Return results
        # (the second value is the state; it is `None` here since our trainer is also used with RNNs)
        return x, None


class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    This inherits from
    [`NLPAutoRegressionConfigs`](../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs)
    """

    # GPT model
    model: GPT
    # Transformer
    transformer: TransformerConfigs
    # Weight decay
    weight_decay: float = 0.1
    # Number of tokens for warmup
    warmup_steps: int = 128 * 128 * 20
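    # With the default batch size of $128$ and sequence length of $128$ below,
    # this works out to $20$ warmup optimizer steps
    # (`transformer_optimizer` divides `warmup_steps` by the number of tokens per step).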

    # Custom optimizer
    optimizer = 'transformer_optimizer'


@option(Configs.transformer, 'GPT')
def _transformer_configs(c: Configs):
    """
    ### Transformer configurations
    """

    # We use our
    # [configurable transformer implementation](../configs.html#TransformerConfigs)
    conf = TransformerConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    # GPT uses GELU activation for the position-wise feedforward layer
    conf.feed_forward_activation = 'GELU'
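    # GELU is $x \Phi(x)$, where $\Phi$ is the standard Gaussian CDF;
    # it is a smoother alternative to ReLU.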

    #
    return conf


def _init_weights(module):
    """
    ### Initialize weights

    Weights of linear layers and embedding layers are initialized
    to $\mathcal{N}(0, 0.02)$
    instead of the default Xavier initialization.
    """

    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return

    module.weight.data.normal_(mean=0.0, std=0.02)

    # Initialize biases to $0$
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()


@option(Configs.model)
def _model(c: Configs):
    """
    Create GPT model and initialize weights
    """
    m = GPT(c.transformer.encoder,
            c.transformer.src_embed,
            c.transformer.generator).to(c.device)

    # Apply custom weight initialization
    m.apply(_init_weights)

    return m


@option(NLPAutoRegressionConfigs.optimizer)
def transformer_optimizer(c: NLPAutoRegressionConfigs):
    """
    ### Create custom optimizer with weight decay

    This code is taken from [minGPT](https://github.com/karpathy/minGPT).
    It applies weight decay only to the weights of linear layers.
    """
    # Collect names of parameters to apply weight decay
    decay = set()
    for mn, m in c.model.named_modules():
        for pn, p in m.named_parameters():
            fpn = f'{mn}.{pn}' if mn else pn  # full param name

            if fpn.endswith('weight') and isinstance(m, nn.Linear):
                decay.add(fpn)

    # Get all the parameters
    param_dict = {pn: p for pn, p in c.model.named_parameters()}
    # Parameters that are not decayed
    no_decay = set(param_dict.keys()) - decay
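    # Everything else (biases, embeddings, and layer-norm parameters) gets no weight decay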

    # Create the PyTorch optimizer parameter groups
    opt_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": c.weight_decay},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]

    # Create a [configurable optimizer](../optimizers/configs.html#OptimizerConfigs),
    # so that we can change these simply by passing
    # a config dictionary.
    optimizer = OptimizerConfigs()

    # Set parameter groups for optimization
    optimizer.parameters = opt_groups
    # Use the [Adam optimizer with warmup and cosine decay](../optimizers/adam_warmup_cosine_decay.html).
    # This is what GPT uses.
    optimizer.optimizer = 'AdamWarmupCosineDecay'
    # Set the model embedding size; this is required if we use the [Noam optimizer](../optimizers/noam.html),
    # which scales the learning rate by $d_{model}^{-0.5}$
    optimizer.d_model = c.d_model
    # Set default weight decay.
    # This is not required since we set the weight decay in the parameter groups
    optimizer.weight_decay = c.weight_decay
    # GPT uses a maximum learning rate of $6 \times 10^{-4}$
    optimizer.learning_rate = 6e-4
    # $\beta_1 = 0.9, \beta_2 = 0.95$
    optimizer.betas = (0.9, 0.95)
    # $\epsilon = 10^{-8}$
    optimizer.eps = 1e-8
    # Weight decay is decoupled from the gradients (as in AdamW)
    optimizer.weight_decouple = True
    # Total number of optimization steps for learning rate cosine decay
    optimizer.total_steps = c.epochs * len(c.text.train) // (c.batch_size * c.seq_len)
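    # As a rough estimate, with the roughly $10^6$ characters of Tiny Shakespeare,
    # $32$ epochs, batch size $128$ and sequence length $128$, this is on the
    # order of $2000$ optimizer steps (the exact count depends on the split).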
    # Number of warmup optimization steps
    optimizer.warmup = c.warmup_steps // (c.batch_size * c.seq_len)

    return optimizer


def main():
    # Create experiment
    experiment.create(name="gpt")
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # Use character level tokenizer
        'tokenizer': 'character',
        # Prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',

        # Use a context size of $128$
        'seq_len': 128,
        # Train for $32$ epochs
        'epochs': 32,
        # Batch size $128$
        'batch_size': 128,
        # Switch between training and validation $10$ times
        # per epoch
        'inner_iterations': 10,

        # Transformer configurations
        'transformer.d_model': 512,
        'transformer.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6
    })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


#
if __name__ == '__main__':
    main()