提交 3b7e41b2 编写于 作者: U u010280923

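Bug fix: in load_prompt_data_4_ppo, convert each encoded prompt to its own torch.long tensor instead of converting the whole list into a single tensor; in RLHF, pick the random prompt from `prompts` (using len(prompts)) instead of self.prompt_token_ids / self.num_prompts.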

上级 30ccef27
...@@ -360,9 +360,8 @@ def load_prompt_data_4_ppo(args): ...@@ -360,9 +360,8 @@ def load_prompt_data_4_ppo(args):
pf = pd.read_csv(args.data_file) pf = pd.read_csv(args.data_file)
for index, row in pf.iterrows(): for index, row in pf.iterrows():
prompt = row["prompt"] prompt = row["prompt"]
prompt_token_ids.append(tokenizer.tokenizer.encode(prompt)) prompt_token_ids.append(
torch.tensor(tokenizer.tokenizer.encode(prompt), dtype=torch.long))
prompt_token_ids = torch.tensor(prompt_token_ids, dtype=torch.long)
return prompt_token_ids return prompt_token_ids
......
...@@ -484,8 +484,8 @@ class RLHF(nn.Module): ...@@ -484,8 +484,8 @@ class RLHF(nn.Module):
# and get the action (sampled sequence from rwkv as well as the action probs) # and get the action (sampled sequence from rwkv as well as the action probs)
# also calculate the reward using reward model and store # also calculate the reward using reward model and store
# 随机挑选一条 prompt # 随机挑选一条 prompt
rand_prompt_index = randrange(0, self.num_prompts) rand_prompt_index = randrange(0, len(prompts))
state = self.prompt_token_ids[rand_prompt_index] state = prompts[rand_prompt_index]
# remove padding from state # remove padding from state
state_mask = state != self.args.pad_value state_mask = state != self.args.pad_value
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册