Commit 51230ba0 authored by u010280923

opt ppo model

Parent 01d9b9c6
@@ -364,8 +364,6 @@ def load_prompt_data_4_ppo(args):
         prompt = row["prompt"]
         prompt_idx = tokenizer.tokenizer.encode(prompt)
-        prompt_idx = prompt_idx[: req_len]
-        prompt_idx = prompt_idx + [0] * (req_len - len(prompt_idx))
         prompt_token_ids.append(
             torch.tensor(prompt_idx, dtype=torch.long))
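Review note: the prompt preprocessing touched in the hunk above follows the usual encode → truncate → right-pad pattern. The helper below is only an illustrative sketch of that pattern, assuming the nested `tokenizer.tokenizer.encode` interface and pad id 0 seen in the hunk; `encode_prompt` and its signature are hypothetical, not part of this repository.

```python
import torch

def encode_prompt(tokenizer, prompt: str, req_len: int) -> torch.Tensor:
    # Encode the prompt text to token ids (assumes the nested
    # tokenizer.tokenizer wrapper used in this file).
    idx = tokenizer.tokenizer.encode(prompt)
    # Truncate to the fixed prompt length ...
    idx = idx[:req_len]
    # ... and right-pad with token id 0 so every prompt has length req_len.
    idx = idx + [0] * (req_len - len(idx))
    return torch.tensor(idx, dtype=torch.long)
```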
@@ -521,7 +521,9 @@ class RWKV(pl.LightningModule):
         sample_num_times = max(1, seq_len - prompt.shape[-1])
         for _ in tqdm(range(sample_num_times), desc="gen responses"):
-            logits, embeds = self.forward(out, ppo_train=True)
+            pad_idx = torch.tensor([[eos_token] * (self.args.n_embd - out.shape[-1])])
+            query_idx = torch.cat((out, pad_idx), dim=-1)
+            logits, embeds = self.forward(query_idx, ppo_train=True)
             logits, embeds = logits[:, -1], embeds[:, -1]
             if exists(filter_logits_fn):
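Review note: the added lines pad the running query with `eos_token` up to a fixed length (`self.args.n_embd` here) before every forward call, then keep only the last position of the returned logits and embeddings. Below is a minimal sketch of that pattern, assuming a model whose `forward(idx, ppo_train=True)` returns per-position `(logits, embeds)`; `padded_forward` and `pad_to` are hypothetical names, not part of this repository.

```python
import torch

def padded_forward(model, out: torch.Tensor, pad_to: int, eos_token: int):
    # Right-pad the current (prompt + generated) token sequence with
    # eos_token so the model always sees a fixed-length query.
    pad_len = pad_to - out.shape[-1]
    pad_idx = torch.full((out.shape[0], pad_len), eos_token, dtype=out.dtype)
    query_idx = torch.cat((out, pad_idx), dim=-1)
    # forward(..., ppo_train=True) is assumed to return per-position
    # logits and embeddings, as in the diff above.
    logits, embeds = model.forward(query_idx, ppo_train=True)
    # Keep only the last position, as the surrounding sampling loop does.
    return logits[:, -1], embeds[:, -1]
```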