From a6502de621f2690085a3cfd09dad305ab00b5540 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sat, 26 Dec 2020 20:46:58 +0530
Subject: [PATCH] shuffle data

---
 labml_nn/hypernetworks/experiment.py | 35 +++++++++++++++++++---------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/labml_nn/hypernetworks/experiment.py b/labml_nn/hypernetworks/experiment.py
index 780ad7e3..c1808178 100644
--- a/labml_nn/hypernetworks/experiment.py
+++ b/labml_nn/hypernetworks/experiment.py
@@ -6,11 +6,12 @@
 from labml import lab, experiment, monit, tracker, logger
 from labml.configs import option
 from labml.logger import Text
 from labml.utils.pytorch import get_modules
-from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
+from labml_helpers.datasets.text import TextDataset, TextFileDataset, SequentialUnBatchedDataset
 from labml_helpers.metrics.accuracy import Accuracy
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
 from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
+from torch.utils.data import DataLoader
 
 from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
 
@@ -48,6 +49,14 @@ class CrossEntropyLoss(Module):
         return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
 
 
+def transpose_batch(batch):
+    transposed_data = list(zip(*batch))
+    src = torch.stack(transposed_data[0], 1)
+    tgt = torch.stack(transposed_data[1], 1)
+
+    return src, tgt
+
+
 class Configs(SimpleTrainValidConfigs):
     """
     ## Configurations
@@ -78,16 +87,20 @@ class Configs(SimpleTrainValidConfigs):
         self.optimizer = optimizer
 
         # Create a sequential data loader for training
-        self.train_loader = SequentialDataLoader(text=self.text.train,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.train_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.train,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)
 
         # Create a sequential data loader for validation
-        self.valid_loader = SequentialDataLoader(text=self.text.valid,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.valid_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.valid,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)
 
         self.state_modules = [self.accuracy]
 
@@ -186,12 +199,12 @@ def main():
                        # A dictionary of configurations to override
                        {'tokenizer': 'character',
                         'text': 'tiny_shakespeare',
-                        'optimizer.learning_rate': 1e-4,
+                        'optimizer.learning_rate': 2.5e-4,
                         'seq_len': 512,
                         'epochs': 128,
                         'batch_size': 2,
-                        'inner_iterations': 10})
+                        'inner_iterations': 25})
 
     # This is needed to initialize models
     conf.n_tokens = conf.text.n_tokens
 
-- 
GitLab
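
Note on the collate function in this patch: SequentialUnBatchedDataset yields
individual (src, tgt) sequence pairs, which lets DataLoader shuffle samples;
transpose_batch then stacks each batch along dim=1 to restore the time-major
[seq_len, batch_size] layout that SequentialDataLoader used to produce. Below
is a minimal, self-contained sketch of the same pattern. ToySequenceDataset is
a hypothetical stand-in for SequentialUnBatchedDataset (assumed to yield
fixed-length (src, tgt) token-index tensors), not the labml_helpers class.

import torch
from torch.utils.data import DataLoader, Dataset


class ToySequenceDataset(Dataset):
    """Yields (src, tgt) pairs where tgt is src shifted by one token."""

    def __init__(self, text: torch.Tensor, seq_len: int):
        self.text = text
        self.seq_len = seq_len

    def __len__(self):
        return (len(self.text) - 1) // self.seq_len

    def __getitem__(self, idx):
        i = idx * self.seq_len
        return (self.text[i:i + self.seq_len],
                self.text[i + 1:i + self.seq_len + 1])


def transpose_batch(batch):
    # `batch` is a list of (src, tgt) pairs of shape [seq_len] each.
    # Stacking along dim=1 yields time-major tensors of shape
    # [seq_len, batch_size], matching what the model expects.
    transposed_data = list(zip(*batch))
    src = torch.stack(transposed_data[0], 1)
    tgt = torch.stack(transposed_data[1], 1)
    return src, tgt


# Random token ids standing in for a character-encoded corpus.
text = torch.randint(0, 65, (1025,))
loader = DataLoader(ToySequenceDataset(text, seq_len=32),
                    batch_size=2,
                    collate_fn=transpose_batch,
                    shuffle=True)
src, tgt = next(iter(loader))
print(src.shape, tgt.shape)  # torch.Size([32, 2]) torch.Size([32, 2])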