From 3b7e41b2bf3b1e99662881df1bd996558125b85f Mon Sep 17 00:00:00 2001
From: u010280923
Date: Fri, 17 Mar 2023 18:00:35 +0800
Subject: [PATCH] bug fixed

---
 src/dataset.py  | 5 ++---
 src/rlhf/ppo.py | 4 ++--
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/dataset.py b/src/dataset.py
index 585e476..170fa68 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -360,9 +360,8 @@ def load_prompt_data_4_ppo(args):
     pf = pd.read_csv(args.data_file)
     for index, row in pf.iterrows():
         prompt = row["prompt"]
-        prompt_token_ids.append(tokenizer.tokenizer.encode(prompt))
-
-    prompt_token_ids = torch.tensor(prompt_token_ids, dtype=torch.long)
+        prompt_token_ids.append(
+            torch.tensor(tokenizer.tokenizer.encode(prompt), dtype=torch.long))
 
     return prompt_token_ids
 
diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index 3fcbd7e..ae2ef5f 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -484,8 +484,8 @@ class RLHF(nn.Module):
                 # and get the action (sampled sequence from rwkv as well as the action probs)
                 # also calculate the reward using reward model and store
                 # 随机挑选一条 prompt
-                rand_prompt_index = randrange(0, self.num_prompts)
-                state = self.prompt_token_ids[rand_prompt_index]
+                rand_prompt_index = randrange(0, len(prompts))
+                state = prompts[rand_prompt_index]
 
                 # remove padding from state
                 state_mask = state != self.args.pad_value
--
GitLab
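
The dataset.py hunk replaces the single torch.tensor() call over the whole list with one tensor per prompt, because prompts of different lengths cannot be packed into a single rectangular LongTensor. Below is a minimal, self-contained sketch of that failure mode and of the fixed representation; the hard-coded token-id lists are made-up stand-ins for tokenizer.tokenizer.encode(prompt) and are not part of the patch.

# Sketch only: illustrates why load_prompt_data_4_ppo now returns a list of
# 1-D tensors. The id lists below are assumed stand-ins for encoded prompts.
import torch

encoded_prompts = [
    [11, 23, 42],          # a 3-token prompt
    [7, 8, 9, 10, 11],     # a 5-token prompt
]

# Old behaviour: converting a ragged list of lists into one tensor fails.
try:
    torch.tensor(encoded_prompts, dtype=torch.long)
except ValueError as err:
    print("single-tensor conversion fails:", err)

# New behaviour: keep a Python list holding one 1-D LongTensor per prompt.
prompt_token_ids = [torch.tensor(ids, dtype=torch.long) for ids in encoded_prompts]
print([t.shape for t in prompt_token_ids])   # [torch.Size([3]), torch.Size([5])]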
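
The ppo.py hunk adapts the experience-collection step to that new return type: the comment 随机挑选一条 prompt means "randomly pick one prompt", and the index is now drawn from the prompts list actually passed in (len(prompts)) rather than from the separately tracked self.num_prompts over self.prompt_token_ids. A rough sketch of that sampling plus the subsequent padding removal follows; pad_value is assumed here, whereas in the repository it comes from args.pad_value.

# Sketch only: sample one prompt from the list produced by the fixed loader,
# then strip padding the way the surrounding RLHF code does with state_mask.
from random import randrange

import torch

pad_value = 0  # assumed padding id; the real value is args.pad_value

prompts = [
    torch.tensor([11, 23, 42, pad_value, pad_value], dtype=torch.long),
    torch.tensor([7, 8, 9, 10, 11], dtype=torch.long),
]

# Pick a prompt uniformly at random from the list itself.
rand_prompt_index = randrange(0, len(prompts))
state = prompts[rand_prompt_index]

# Remove padding from the sampled state.
state_mask = state != pad_value
state = state[state_mask]
print(state)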