Commit a6502de6 authored by Varuna Jayasiri

shuffle data

Parent 716dda5f
```diff
@@ -6,11 +6,12 @@
 from labml import lab, experiment, monit, tracker, logger
 from labml.configs import option
 from labml.logger import Text
 from labml.utils.pytorch import get_modules
-from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
+from labml_helpers.datasets.text import TextDataset, TextFileDataset, SequentialUnBatchedDataset
 from labml_helpers.metrics.accuracy import Accuracy
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
 from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
+from torch.utils.data import DataLoader
 
 from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
```
```diff
@@ -48,6 +49,14 @@ class CrossEntropyLoss(Module):
         return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
 
 
+def transpose_batch(batch):
+    transposed_data = list(zip(*batch))
+    src = torch.stack(transposed_data[0], 1)
+    tgt = torch.stack(transposed_data[1], 1)
+
+    return src, tgt
+
+
 class Configs(SimpleTrainValidConfigs):
     """
     ## Configurations
```
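`transpose_batch` is the `collate_fn` handed to `DataLoader` below: it takes the list of `(src, tgt)` sample pairs collected by the loader and stacks each side along dimension 1, producing time-major tensors of shape `[seq_len, batch_size]`, the layout the LSTM-based models consume. A quick shape check (the token values here are invented for illustration):

```python
import torch


def transpose_batch(batch):
    # batch is a list of (src, tgt) pairs, each of shape [seq_len]
    transposed_data = list(zip(*batch))
    src = torch.stack(transposed_data[0], 1)  # stack along dim 1 -> time-major
    tgt = torch.stack(transposed_data[1], 1)
    return src, tgt


batch = [(torch.tensor([1, 2, 3, 4]), torch.tensor([2, 3, 4, 5])),
         (torch.tensor([6, 7, 8, 9]), torch.tensor([7, 8, 9, 10]))]
src, tgt = transpose_batch(batch)
print(src.shape)  # torch.Size([4, 2]) -- [seq_len, batch_size]
print(src[:, 0])  # tensor([1, 2, 3, 4]) -- sample 0 sits in column 0
```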
```diff
@@ -78,16 +87,20 @@ class Configs(SimpleTrainValidConfigs):
         self.optimizer = optimizer
 
         # Create a sequential data loader for training
-        self.train_loader = SequentialDataLoader(text=self.text.train,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.train_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.train,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)
         # Create a sequential data loader for validation
-        self.valid_loader = SequentialDataLoader(text=self.text.valid,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
+        self.valid_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.valid,
+                                                                  dataset=self.text,
+                                                                  seq_len=self.seq_len),
+                                       batch_size=self.batch_size,
+                                       collate_fn=transpose_batch,
+                                       shuffle=True)
 
         self.state_modules = [self.accuracy]
```
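The old `SequentialDataLoader` walked the text in order and produced pre-batched contiguous chunks, so sample order was fixed across epochs. `SequentialUnBatchedDataset` instead exposes each `seq_len` window as an individual dataset item, which lets a standard PyTorch `DataLoader` shuffle the windows every epoch; the trade-off is that hidden state can no longer meaningfully carry over between consecutive batches. The real class lives in `labml_helpers`; below is a minimal sketch of the idea, assuming it simply cuts the token stream into consecutive next-token-prediction windows (the class and variable names here are hypothetical):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class UnBatchedTextDataset(Dataset):
    """Hypothetical stand-in for labml_helpers' SequentialUnBatchedDataset:
    each item is one (src, tgt) window of seq_len tokens."""

    def __init__(self, text: torch.Tensor, seq_len: int):
        self.text = text        # 1-D tensor of token ids
        self.seq_len = seq_len

    def __len__(self):
        # Each window needs seq_len inputs plus one extra token for targets
        return (len(self.text) - 1) // self.seq_len

    def __getitem__(self, idx):
        i = idx * self.seq_len
        src = self.text[i:i + self.seq_len]           # inputs
        tgt = self.text[i + 1:i + self.seq_len + 1]   # next-token targets
        return src, tgt


# Wired up the same way as in the commit, reusing transpose_batch from above
tokens = torch.randint(0, 65, (10_000,))  # dummy character-id stream
loader = DataLoader(UnBatchedTextDataset(tokens, seq_len=512),
                    batch_size=2, collate_fn=transpose_batch, shuffle=True)
```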
```diff
@@ -186,12 +199,12 @@ def main():
                        # A dictionary of configurations to override
                        {'tokenizer': 'character',
                         'text': 'tiny_shakespeare',
-                        'optimizer.learning_rate': 1e-4,
+                        'optimizer.learning_rate': 2.5e-4,
                         'seq_len': 512,
                         'epochs': 128,
                         'batch_size': 2,
-                        'inner_iterations': 10})
+                        'inner_iterations': 25})
 
     # This is needed to initialize models
     conf.n_tokens = conf.text.n_tokens
```
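These overrides are the dictionary passed as the second argument to labml's `experiment.configs`; the commit raises the learning rate from 1e-4 to 2.5e-4 and `inner_iterations` (how many times the run alternates between training and validation within an epoch) from 10 to 25. A sketch of how `main()` presumably wraps this call, following the usual labml experiment pattern; the experiment name and the `experiment.start()` / `conf.run()` scaffolding are assumed, not taken from this diff:

```python
from labml import experiment


def main():
    # Experiment name assumed for illustration
    experiment.create(name='hyper_lstm')
    conf = Configs()
    # A dictionary of configurations to override (values from this commit)
    experiment.configs(conf,
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'seq_len': 512,
                        'epochs': 128,
                        'batch_size': 2,
                        'inner_iterations': 25})
    # This is needed to initialize models
    conf.n_tokens = conf.text.n_tokens
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()
```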