Commit a6502de6 authored by Varuna Jayasiri

shuffle data

Parent 716dda5f
@@ -6,11 +6,12 @@ from labml import lab, experiment, monit, tracker, logger
 from labml.configs import option
 from labml.logger import Text
 from labml.utils.pytorch import get_modules
-from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
+from labml_helpers.datasets.text import TextDataset, TextFileDataset, SequentialUnBatchedDataset
 from labml_helpers.metrics.accuracy import Accuracy
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
 from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
+from torch.utils.data import DataLoader
 
 from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
@@ -48,6 +49,14 @@ class CrossEntropyLoss(Module):
         return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
 
 
+def transpose_batch(batch):
+    transposed_data = list(zip(*batch))
+    src = torch.stack(transposed_data[0], 1)
+    tgt = torch.stack(transposed_data[1], 1)
+
+    return src, tgt
+
+
 class Configs(SimpleTrainValidConfigs):
     """
     ## Configurations
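
The new transpose_batch collate function is what lets the rest of the experiment stay unchanged: DataLoader hands it a list of (src, tgt) sample pairs, and it stacks them along dimension 1 so batches come out sequence-first, shaped [seq_len, batch_size], presumably matching the layout the previous SequentialDataLoader produced. A minimal standalone sketch of its behavior (the toy batch below is illustrative, not from the commit):

import torch

def transpose_batch(batch):
    # `batch` is a list of (src, tgt) pairs, each a 1-D tensor of shape [seq_len]
    transposed_data = list(zip(*batch))
    # Stack along dim=1 so the result is sequence-first: [seq_len, batch_size]
    src = torch.stack(transposed_data[0], 1)
    tgt = torch.stack(transposed_data[1], 1)
    return src, tgt

# Two toy samples of length 5; targets are inputs shifted by one token
batch = [(torch.arange(5), torch.arange(1, 6)) for _ in range(2)]
src, tgt = transpose_batch(batch)
assert src.shape == (5, 2)  # [seq_len, batch_size]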
@@ -78,16 +87,20 @@ class Configs(SimpleTrainValidConfigs):
         self.optimizer = optimizer
 
         # Create a sequential data loader for training
-        self.train_loader = SequentialDataLoader(text=self.text.train,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.train_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.train,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)
         # Create a sequential data loader for validation
-        self.valid_loader = SequentialDataLoader(text=self.text.valid,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.valid_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.valid,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)
 
         self.state_modules = [self.accuracy]
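
SequentialUnBatchedDataset is what makes shuffle=True possible: rather than one long pre-batched stream, it exposes the text as independently indexable (src, tgt) windows that a standard DataLoader can sample in any order. The real implementation lives in labml_helpers.datasets.text; the sketch below is only an assumption about its behavior, not the library code (the dataset= argument, which presumably supplies the vocabulary, is omitted):

import torch
from torch.utils.data import Dataset

class SequentialUnBatchedDatasetSketch(Dataset):
    # Hypothetical stand-in for labml_helpers' SequentialUnBatchedDataset
    def __init__(self, text: torch.Tensor, seq_len: int):
        self.text = text        # 1-D tensor of token ids
        self.seq_len = seq_len

    def __len__(self):
        # Number of non-overlapping windows of length seq_len
        return (len(self.text) - 1) // self.seq_len

    def __getitem__(self, idx):
        start = idx * self.seq_len
        # Next-token prediction: the target is the input shifted by one token
        src = self.text[start: start + self.seq_len]
        tgt = self.text[start + 1: start + self.seq_len + 1]
        return src, tgt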
@@ -186,12 +199,12 @@ def main():
                        # A dictionary of configurations to override
                        {'tokenizer': 'character',
                         'text': 'tiny_shakespeare',
-                         'optimizer.learning_rate': 1e-4,
+                         'optimizer.learning_rate': 2.5e-4,
                         'seq_len': 512,
                         'epochs': 128,
                         'batch_size': 2,
-                         'inner_iterations': 10})
+                         'inner_iterations': 25})
 
     # This is needed to initialize models
     conf.n_tokens = conf.text.n_tokens
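
Putting the pieces together with the overridden hyperparameters (a sketch reusing the hypothetical SequentialUnBatchedDatasetSketch and transpose_batch from above; the random tensor merely stands in for the tokenized tiny_shakespeare text):

import torch
from torch.utils.data import DataLoader

text = torch.randint(0, 65, (100_000,))  # stand-in token ids, not the real corpus
ds = SequentialUnBatchedDatasetSketch(text, seq_len=512)
loader = DataLoader(ds, batch_size=2, collate_fn=transpose_batch, shuffle=True)

src, tgt = next(iter(loader))
print(src.shape)  # torch.Size([512, 2]) -- [seq_len, batch_size]

The learning-rate increase to 2.5e-4 and the change of inner_iterations from 10 to 25 are plain hyperparameter adjustments; in labml_helpers' train-valid loop, inner_iterations appears to control how many chunks each epoch is split into for interleaved training and validation.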