diff --git a/src/dataset.py b/src/dataset.py index 001424cb5dfce365c3a6df11353412e6ad196d78..e65c312c3f2b3ba12533194d55f7e4446f1396ce 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -305,7 +305,7 @@ class RMDataset(Dataset): def __getitem__(self, index): ctx_len = self.args.ctx_len - req_len = ctx_len + 1 + req_len = ctx_len prompt_prefer_idx, prompt_alter_idx, prompt_prefer_mask, prompt_alter_mask = self.data[index] prompt_prefer_idx = prompt_prefer_idx[: req_len]