提交 3b7e41b2 编写于 作者: U u010280923

bug fixed

上级 30ccef27
......@@ -360,9 +360,8 @@ def load_prompt_data_4_ppo(args):
pf = pd.read_csv(args.data_file)
for index, row in pf.iterrows():
prompt = row["prompt"]
- prompt_token_ids.append(tokenizer.tokenizer.encode(prompt))
- prompt_token_ids = torch.tensor(prompt_token_ids, dtype=torch.long)
+ prompt_token_ids.append(
+ torch.tensor(tokenizer.tokenizer.encode(prompt), dtype=torch.long))
return prompt_token_ids
......
......@@ -484,8 +484,8 @@ class RLHF(nn.Module):
# and get the action (sampled sequence from rwkv as well as the action probs)
# also calculate the reward using reward model and store
# 随机挑选一条 prompt
- rand_prompt_index = randrange(0, self.num_prompts)
- state = self.prompt_token_ids[rand_prompt_index]
+ rand_prompt_index = randrange(0, len(prompts))
+ state = prompts[rand_prompt_index]
# remove padding from state
state_mask = state != self.args.pad_value
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册