from typing import Callable, Any

import torch
import torch.nn as nn
from labml import lab, experiment, monit, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml.utils.pytorch import get_modules
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex

from labml_nn.hypernetworks.hyper_lstm import HyperLSTM


class AutoregressiveModel(Module):
    """
    ## Auto-regressive model
    """

    def __init__(self, n_vocab: int, d_model: int, n_rhn: int, n_z: int):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
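        # Single-layer HyperLSTM; `n_rhn` is the hidden size of the hyper-network RNN and
        # `n_z` is the size of the embeddings it generates to modulate the main LSTM weights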
        self.lstm = HyperLSTM(d_model, d_model, n_rhn, n_z, 1)
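        # Final linear layer that projects LSTM outputs to logits over the vocabulary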
        self.generator = nn.Linear(d_model, n_vocab)

    def __call__(self, x: torch.Tensor):
        # Embed the tokens
        x = self.src_embed(x)
        # Run the embeddings through the HyperLSTM
        res, state = self.lstm(x)
        # Generate logits of the next token
        return self.generator(res), state


class CrossEntropyLoss(Module):
    """
    Cross entropy loss
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, outputs, targets):
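        # Flatten the batch and sequence dimensions before computing cross-entropy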
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))


class Configs(SimpleTrainValidConfigs):
    """
    ## Configurations

    The default configs can and will be overridden when we start the experiment.
    """

    model: AutoregressiveModel
    text: TextDataset
    batch_size: int = 20
    seq_len: int = 512
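    # Vocabulary size; set from the dataset in `main()` before the model is built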
    n_tokens: int
    tokenizer: Callable = 'character'

    is_save_models = True

    optimizer: torch.optim.Adam = 'transformer_optimizer'

    accuracy = Accuracy()
    loss_func = CrossEntropyLoss()

    def init(self):
        # Create a configurable optimizer.
        # Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
        optimizer = OptimizerConfigs()
        optimizer.parameters = self.model.parameters()
        optimizer.optimizer = 'Adam'
        self.optimizer = optimizer

        # Create a sequential data loader for training
        self.train_loader = SequentialDataLoader(text=self.text.train,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        # Create a sequential data loader for validation
        self.valid_loader = SequentialDataLoader(text=self.text.valid,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        # Register the accuracy metric as a state module so the trainer tracks and resets its state
        self.state_modules = [self.accuracy]

    def sample(self):
        """
        Sampling function to generate samples periodically while training
        """
        prompt = 'It is'
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens; the full prompt is re-tokenized and fed through the model at each step
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, state = self.model(data)
            output = output.cpu()
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.text.itos[output[-1]], Text.value)]

        logger.log(log)

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        This method is called for each batch
        """
        self.model.train(self.mode.is_train)

        # Get data and target labels
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
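            # Update the global step by the number of tokens processed in this batch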
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Run the model
        output, state = self.model(data)

        # Calculate loss
        loss = self.loss_func(output, target)
        # Calculate accuracy
        self.accuracy(output, target)

        # Log the loss
        tracker.add("loss.", loss)

        # If we are in training mode, compute gradients and take an optimizer step
        if self.mode.is_train:
            loss.backward()
            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
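            # Clear the gradients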
            self.optimizer.zero_grad()

        tracker.save()


def character_tokenizer(x: str):
    return list(x)


@option(Configs.tokenizer)
def character():
    """
    Character level tokenizer
    """
    return character_tokenizer


@option(Configs.text)
def tiny_shakespeare(c: Configs):
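    """
    Load the Tiny Shakespeare dataset (downloads the text file if it is not present)
    """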
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Initialize the auto-regressive model
    """
    m = AutoregressiveModel(c.n_tokens, 512, 16, 16)
    return m.to(c.device)


def main():
    # Create experiment
    experiment.create(name="knn_lm", comment='')
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1e-4,
                        'seq_len': 512,
                        'epochs': 128,
                        'batch_size': 2,
                        'inner_iterations': 10})

    # Set the vocabulary size from the dataset; this is needed to initialize the model
    conf.n_tokens = conf.text.n_tokens

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Initialize the optimizer and the data loaders
    conf.init()
    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()


if __name__ == '__main__':
    main()