diff --git a/src/model.py b/src/model.py index 8e8bdfdf93059b8d479c1463de5ebf2ee4716558..9322ba1c4ccf47b239823ee61dd80ee6e80560fa 100644 --- a/src/model.py +++ b/src/model.py @@ -521,7 +521,7 @@ class RWKV(pl.LightningModule): sample_num_times = max(1, seq_len - prompt.shape[-1]) for _ in tqdm(range(sample_num_times), desc="gen responses"): - pad_idx = torch.tensor([[eos_token] * (self.args.n_embd - out.shape[-1])]) + pad_idx = torch.tensor([[eos_token] * (self.args.ctx_len - out.shape[-1])]) query_idx = torch.cat((out, pad_idx), dim=-1) logits, embeds = self.forward(query_idx, ppo_train=True) logits, embeds = logits[:, -1], embeds[:, -1]