from typing import Callable, Any

import torch
import torch.nn as nn
from labml import lab, experiment, monit, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml.utils.pytorch import get_modules
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex

from labml_nn.hypernetworks.hyper_lstm import HyperLSTM

class AutoregressiveModel(Module):
    """
    ## Auto regressive model

    Embeds the input tokens, runs the embeddings through a `HyperLSTM` and
    projects the outputs to logits over the vocabulary.
    """

    def __init__(self, n_vocab: int, d_model: int, n_rhn: int, n_z: int):
        """
        * `n_vocab` is the number of tokens in the vocabulary
        * `d_model` is the embedding size; it is also passed to `HyperLSTM` as its first two arguments
        * `n_rhn` and `n_z` are forwarded to `HyperLSTM`
        """
        # Initialize the base module first; `torch.nn.Module` refuses sub-module
        # assignment before its `__init__` has run
        # (assumes `labml_helpers.module.Module` derives from it).
        super().__init__()
        # Token embedding module.
        # Only the vocabulary and embedding sizes belong here: any further positional
        # arguments would be interpreted as `padding_idx` and `max_norm`.
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Recurrent core
        self.lstm = HyperLSTM(d_model, d_model, n_rhn, n_z, 1)
        # Projection from the model dimension to vocabulary logits
        self.generator = nn.Linear(d_model, n_vocab)

    def __call__(self, x: torch.Tensor):
        """
        Return `(logits, state)` where `state` is whatever the `HyperLSTM` returns
        as its second output.
        """
        # Embed the tokens
        x = self.src_embed(x)
        # Run the embeddings through the HyperLSTM
        res, state = self.lstm(x)
        # Generate logits of the next token
        return self.generator(res), state

class CrossEntropyLoss(Module):
    """
    Cross entropy loss
    """

    def __init__(self):
        # Initialize the base module before assigning the loss sub-module
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, outputs, targets):
        # Collapse every leading dimension so that `outputs` becomes
        # `[N, n_classes]` and `targets` becomes `[N]`
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))

class Configs(SimpleTrainValidConfigs):
    """
    ## Configurations

    The default configs can and will be over-ridden when we start the experiment
    """

    model: AutoregressiveModel
    text: TextDataset
    batch_size: int = 20
    seq_len: int = 512
    n_tokens: int
    tokenizer: Callable = 'character'

    is_save_models = True

    optimizer: torch.optim.Adam = 'transformer_optimizer'

    accuracy = Accuracy()
    loss_func = CrossEntropyLoss()

    def init(self):
        """
        Create the optimizer and the training / validation data loaders
        """
        # Create a configurable optimizer.
        # Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
        optimizer = OptimizerConfigs()
        optimizer.parameters = self.model.parameters()
        optimizer.optimizer = 'Adam'
        self.optimizer = optimizer

        # Create a sequential data loader for training
        self.train_loader = SequentialDataLoader(text=self.text.train,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        # Create a sequential data loader for validation
        self.valid_loader = SequentialDataLoader(text=self.text.valid,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        self.state_modules = [self.accuracy]

    def sample(self):
        """
        Sampling function to generate samples periodically while training
        """
        prompt = 'It is'
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for _ in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, _ = self.model(data)
            output = output.cpu()
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Look up the token predicted after the last prompt position
            next_token = self.text.itos[output[-1]]
            # Add the prediction to prompt
            prompt += next_token
            # Add the prediction for logging
            log.append((next_token, Text.value))

        # Print the sampled output; without this the collected `log` is discarded
        logger.log(log)

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        This method is called for each batch
        """

        # Get data and target labels
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Advance the global step by the number of elements in the batch
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Run the model
        output, _ = self.model(data)

        # Calculate loss
        loss = self.loss_func(output, target)
        # Calculate accuracy
        # NOTE(review): some `labml_helpers` versions also need an explicit
        # `self.accuracy.track()` here - confirm against the installed version.
        self.accuracy(output, target)

        # Log the loss
        tracker.add("loss.", loss)

        #  If we are in training mode, calculate the gradients
        if self.mode.is_train:
            # Back-propagate and update the parameters
            loss.backward()
            self.optimizer.step()
            # Track the model on the last batch, while gradients are still populated
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients for the next batch
            self.optimizer.zero_grad()

        # Write the tracked metrics
        tracker.save()


def character_tokenizer(x: str):
    """
    Split the text `x` into a list of its individual characters
    """
    return [*x]

@option(Configs.tokenizer)
def character():
    """
    Character level tokenizer
    """
    # Registered as the `'character'` option of `Configs.tokenizer`, so it can be
    # selected by name when configuring the experiment
    return character_tokenizer

@option(Configs.text)
def tiny_shakespeare(c: Configs):
    """
    Tiny Shakespeare dataset, tokenized with the configured tokenizer
    """
    # NOTE(review): `url` lets the dataset be fetched when the local file is missing;
    # confirm the keyword against the installed `labml_helpers` version.
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Initialize the auto-regressive model
    """
    # Vocabulary size comes from the configs; 512 is the model dimension and
    # the two 16s are the `n_rhn` and `n_z` arguments of `AutoregressiveModel`
    m = AutoregressiveModel(c.n_tokens, 512, 16, 16)
    return m.to(c.device)

def main():
    """
    Create the experiment, load the configurations and run the training loop
    """
    # Create experiment
    experiment.create(name="knn_lm", comment='')
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1e-4,
                        'seq_len': 512,
                        'epochs': 128,
                        'batch_size': 2,
                        'inner_iterations': 10})

    # This is needed to initialize models
    conf.n_tokens = conf.text.n_tokens

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()

# Run the experiment only when this file is executed as a script
if __name__ == '__main__':
    main()