diff --git a/src/dataset.py b/src/dataset.py index 170fa687c1656a55fffedb3b272136379c01a0ce..b3fab1b9973785ac47740d244afe6a1cedf9e645 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -356,12 +356,18 @@ def load_prompt_data_4_ppo(args): ] # [vocab, vocab] for Pile model tokenizer = TOKENIZER(WORD_NAME) + ctx_len = args.ctx_len + req_len = ctx_len pf = pd.read_csv(args.data_file) for index, row in pf.iterrows(): prompt = row["prompt"] + prompt_idx = tokenizer.tokenizer.encode(prompt) + prompt_idx = prompt_idx[: req_len] + prompt_idx = prompt_idx + [0] * (req_len - len(prompt_idx)) + prompt_token_ids.append( - torch.tensor(tokenizer.tokenizer.encode(prompt), dtype=torch.long)) + torch.tensor(prompt_idx, dtype=torch.long)) return prompt_token_ids