提交 2164e3e5 编写于 作者: U u010280923

opt ppo model

上级 51230ba0
......@@ -521,7 +521,7 @@ class RWKV(pl.LightningModule):
sample_num_times = max(1, seq_len - prompt.shape[-1])
for _ in tqdm(range(sample_num_times), desc="gen responses"):
- pad_idx = torch.tensor([[eos_token] * (self.args.n_embd - out.shape[-1])])
+ pad_idx = torch.tensor([[eos_token] * (self.args.ctx_len - out.shape[-1])])
query_idx = torch.cat((out, pad_idx), dim=-1)
logits, embeds = self.forward(query_idx, ppo_train=True)
logits, embeds = logits[:, -1], embeds[:, -1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册