def __getitem__(self, index):
    """Build one next-token-prediction sample from a (question, answer) pair.

    The question and answer sequences are concatenated, truncated to
    ``ctx_len + 1`` items and right-padded with ``0`` to exactly that
    length. The input is that sequence without its last item and the
    target is the same sequence without its first item, so ``y[i]`` is
    the item that follows ``x[i]``.

    Args:
        index: Position of the pair in ``self.data``.

    Returns:
        A tuple ``(x, y, z)`` of ``torch.long`` tensors, each of length
        ``self.args.ctx_len``:
        ``x`` -- input sequence,
        ``y`` -- target sequence (``x`` shifted left by one),
        ``z`` -- mask that is ``0`` for the first ``len(question) - 1``
        positions and ``1`` for every later position.

    Note:
        Padding positions after the end of the answer are also marked
        ``1`` in ``z``.
        NOTE(review): presumably ``z`` is a loss mask and ``0`` is the
        pad/end-of-text token id -- confirm against the training step.
    """
    ctx_len = self.args.ctx_len
    # One extra item so the input and the shifted target both span ctx_len.
    req_len = ctx_len + 1

    question, answer = self.data[index]
    tokens = (question + answer)[:req_len]
    tokens = tokens + [0] * (req_len - len(tokens))

    x = torch.tensor(tokens[:-1], dtype=torch.long)
    y = torch.tensor(tokens[1:], dtype=torch.long)

    # Position i of y is tokens[i + 1], so the first answer item is the
    # target at position len(question) - 1. Clamp the masked prefix to
    # [0, ctx_len] so z keeps length ctx_len even when the question is
    # empty or longer than the context window.
    n_masked = min(max(len(question) - 1, 0), ctx_len)
    z = torch.tensor(
        [0] * n_masked + [1] * (ctx_len - n_masked), dtype=torch.long
    )

    return x, y, z