From 2164e3e5e9595834261c72218f79405f1f5a4509 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Tue, 21 Mar 2023 00:18:01 +0800 Subject: [PATCH] opt ppo model --- src/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.py b/src/model.py index 8e8bdfd..9322ba1 100644 --- a/src/model.py +++ b/src/model.py @@ -521,7 +521,7 @@ class RWKV(pl.LightningModule): sample_num_times = max(1, seq_len - prompt.shape[-1]) for _ in tqdm(range(sample_num_times), desc="gen responses"): - pad_idx = torch.tensor([[eos_token] * (self.args.n_embd - out.shape[-1])]) + pad_idx = torch.tensor([[eos_token] * (self.args.ctx_len - out.shape[-1])]) query_idx = torch.cat((out, pad_idx), dim=-1) logits, embeds = self.forward(query_idx, ppo_train=True) logits, embeds = logits[:, -1], embeds[:, -1] -- GitLab