diff --git a/src/dataset.py b/src/dataset.py
index 585e47614ab88fd7914f01902371d991197628c1..170fa687c1656a55fffedb3b272136379c01a0ce 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -360,9 +360,8 @@ def load_prompt_data_4_ppo(args):
     pf = pd.read_csv(args.data_file)
     for index, row in pf.iterrows():
         prompt = row["prompt"]
-        prompt_token_ids.append(tokenizer.tokenizer.encode(prompt))
-
-    prompt_token_ids = torch.tensor(prompt_token_ids, dtype=torch.long)
+        prompt_token_ids.append(
+            torch.tensor(tokenizer.tokenizer.encode(prompt), dtype=torch.long))

     return prompt_token_ids

diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index 3fcbd7e5d438e7e6ae092e34c0676b19d22a6aa0..ae2ef5fa0b65c4a26a5244c0281b3e7336a1ece0 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -484,8 +484,8 @@ class RLHF(nn.Module):
        # and get the action (sampled sequence from rwkv as well as the action probs)
        # also calculate the reward using reward model and store
        # 随机挑选一条 prompt
-        rand_prompt_index = randrange(0, self.num_prompts)
-        state = self.prompt_token_ids[rand_prompt_index]
+        rand_prompt_index = randrange(0, len(prompts))
+        state = prompts[rand_prompt_index]

        # remove padding from state
        state_mask = state != self.args.pad_value
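
In effect, load_prompt_data_4_ppo now returns a list of variable-length 1-D LongTensors instead of trying to stack every encoded prompt into one rectangular tensor (torch.tensor() raises a ValueError when prompts tokenize to different lengths), and the PPO rollout samples an index from the prompts list it actually receives rather than from self.num_prompts, so the index range and the stored prompts cannot drift apart. Below is a minimal sketch of the new behavior; the encode() helper and the hard-coded prompt strings are hypothetical stand-ins for tokenizer.tokenizer.encode() and the args.data_file CSV.

import torch
from random import randrange

def encode(prompt):
    # Hypothetical stand-in tokenizer: one token id per character.
    return [ord(c) for c in prompt]

prompts_text = ["hello world", "hi"]  # stand-in for rows read from args.data_file

# Old approach: stacking ragged token-id lists into a single tensor fails.
try:
    torch.tensor([encode(p) for p in prompts_text], dtype=torch.long)
except ValueError as err:
    print("stacking ragged prompts fails:", err)

# New approach: keep each prompt as its own 1-D LongTensor.
prompt_token_ids = [torch.tensor(encode(p), dtype=torch.long) for p in prompts_text]

# The PPO rollout now indexes the list it was given, so the random index
# is always valid for the prompts that were actually loaded.
rand_prompt_index = randrange(0, len(prompt_token_ids))
state = prompt_token_ids[rand_prompt_index]
print(state.shape)  # e.g. torch.Size([11]) or torch.Size([2])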