提交 01d9b9c6 编写于 作者: U u010280923

opt ppo model

上级 8aabd2fc
...@@ -356,12 +356,18 @@ def load_prompt_data_4_ppo(args): ...@@ -356,12 +356,18 @@ def load_prompt_data_4_ppo(args):
] # [vocab, vocab] for Pile model ] # [vocab, vocab] for Pile model
tokenizer = TOKENIZER(WORD_NAME) tokenizer = TOKENIZER(WORD_NAME)
ctx_len = args.ctx_len
req_len = ctx_len
pf = pd.read_csv(args.data_file) pf = pd.read_csv(args.data_file)
for index, row in pf.iterrows(): for index, row in pf.iterrows():
prompt = row["prompt"] prompt = row["prompt"]
prompt_idx = tokenizer.tokenizer.encode(prompt)
prompt_idx = prompt_idx[: req_len]
prompt_idx = prompt_idx + [0] * (req_len - len(prompt_idx))
prompt_token_ids.append( prompt_token_ids.append(
torch.tensor(tokenizer.tokenizer.encode(prompt), dtype=torch.long)) torch.tensor(prompt_idx, dtype=torch.long))
return prompt_token_ids return prompt_token_ids
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册