From 01d9b9c6ad5e084ecf46dfc5d48183851e599f09 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Mon, 20 Mar 2023 23:37:06 +0800 Subject: [PATCH] opt ppo model --- src/dataset.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/dataset.py b/src/dataset.py index 170fa68..b3fab1b 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -356,12 +356,18 @@ def load_prompt_data_4_ppo(args): ] # [vocab, vocab] for Pile model tokenizer = TOKENIZER(WORD_NAME) + ctx_len = args.ctx_len + req_len = ctx_len pf = pd.read_csv(args.data_file) for index, row in pf.iterrows(): prompt = row["prompt"] + prompt_idx = tokenizer.tokenizer.encode(prompt) + prompt_idx = prompt_idx[: req_len] + prompt_idx = prompt_idx + [0] * (req_len - len(prompt_idx)) + prompt_token_ids.append( - torch.tensor(tokenizer.tokenizer.encode(prompt), dtype=torch.long)) + torch.tensor(prompt_idx, dtype=torch.long)) return prompt_token_ids -- GitLab