提交 6f1de051 编写于 作者: CSDN-Ada助手's avatar CSDN-Ada助手

fix bug

上级 63b543cf
......@@ -243,17 +243,17 @@ class S2SDataset(Dataset):
def __getitem__(self, index):
ctx_len = self.args.ctx_len
req_len = ctx_len + 1
question, answer = self.data[index]
text = question + answer
text = text[:ctx_len]
text = text[:req_len]
text = text + [0] * (ctx_len - len(text))
x = torch.tensor(text, dtype=torch.long)
text = text + [0] * (req_len - len(text))
x = torch.tensor(text[:-1], dtype=torch.long)
answer = answer + [0] * (ctx_len - len(answer))
y = torch.tensor(answer, dtype=torch.long)
y = torch.tensor(text[1:], dtype=torch.long)
z = [1] * len(question) + [0] * (ctx_len - len(question))
z = [0] * (len(question) - 1) + [1] * (ctx_len - (len(question) - 1))
z = torch.tensor(z, dtype=torch.long)
return x, y, z
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册