提交 01d9b9c6 编写于 作者: U u010280923

opt ppo model

上级 8aabd2fc
......@@ -356,12 +356,18 @@ def load_prompt_data_4_ppo(args):
] # [vocab, vocab] for Pile model
tokenizer = TOKENIZER(WORD_NAME)
ctx_len = args.ctx_len
req_len = ctx_len
pf = pd.read_csv(args.data_file)
for index, row in pf.iterrows():
prompt = row["prompt"]
prompt_idx = tokenizer.tokenizer.encode(prompt)
prompt_idx = prompt_idx[: req_len]
prompt_idx = prompt_idx + [0] * (req_len - len(prompt_idx))
prompt_token_ids.append(
torch.tensor(tokenizer.tokenizer.encode(prompt), dtype=torch.long))
torch.tensor(prompt_idx, dtype=torch.long))
return prompt_token_ids
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册