diff --git a/src/dataset.py b/src/dataset.py
index 8057d6ef9c2957a93f07e3dcc9e6b4c916fa99fe..29188ec151d5decd957593bae3efd83bf9894d71 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -243,17 +243,17 @@ class S2SDataset(Dataset):
 
     def __getitem__(self, index):
         ctx_len = self.args.ctx_len
+        req_len = ctx_len + 1
         question, answer = self.data[index]
         text = question + answer
-        text = text[:ctx_len]
+        text = text[:req_len]
 
-        text = text + [0] * (ctx_len - len(text))
-        x = torch.tensor(text, dtype=torch.long)
+        text = text + [0] * (req_len - len(text))
+        x = torch.tensor(text[:-1], dtype=torch.long)
 
-        answer = answer + [0] * (ctx_len - len(answer))
-        y = torch.tensor(answer, dtype=torch.long)
+        y = torch.tensor(text[1:], dtype=torch.long)
 
-        z = [1] * len(question) + [0] * (ctx_len - len(question))
+        z = [0] * (len(question) - 1) + [1] * (ctx_len - (len(question) - 1))
         z = torch.tensor(z, dtype=torch.long)
 
         return x, y, z