From 6f1de05133856e5e50f4111fa95bf99d30feb7ca Mon Sep 17 00:00:00 2001 From: chenlong Date: Fri, 3 Mar 2023 21:18:17 +0800 Subject: [PATCH] fix bug --- src/dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/dataset.py b/src/dataset.py index 8057d6e..29188ec 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -243,17 +243,17 @@ class S2SDataset(Dataset): def __getitem__(self, index): ctx_len = self.args.ctx_len + req_len = ctx_len + 1 question, answer = self.data[index] text = question + answer - text = text[:ctx_len] + text = text[:req_len] - text = text + [0] * (ctx_len - len(text)) - x = torch.tensor(text, dtype=torch.long) + text = text + [0] * (req_len - len(text)) + x = torch.tensor(text[:-1], dtype=torch.long) - answer = answer + [0] * (ctx_len - len(answer)) - y = torch.tensor(answer, dtype=torch.long) + y = torch.tensor(text[1:], dtype=torch.long) - z = [1] * len(question) + [0] * (ctx_len - len(question)) + z = [0] * (len(question) - 1) + [1] * (ctx_len - (len(question) - 1)) z = torch.tensor(z, dtype=torch.long) return x, y, z -- GitLab