diff --git a/src/dataset.py b/src/dataset.py
index b3fab1b9973785ac47740d244afe6a1cedf9e645..399635dfb8d6e1c6dc9cc2be3e37b31eb9c45d24 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -364,8 +364,6 @@ def load_prompt_data_4_ppo(args):
         prompt = row["prompt"]
         prompt_idx = tokenizer.tokenizer.encode(prompt)
         prompt_idx = prompt_idx[: req_len]
-        prompt_idx = prompt_idx + [0] * (req_len - len(prompt_idx))
-
         prompt_token_ids.append(
             torch.tensor(prompt_idx, dtype=torch.long))
 
diff --git a/src/model.py b/src/model.py
index ca8559b89d32d586620d1d4392ca3fb80c736fcf..8e8bdfdf93059b8d479c1463de5ebf2ee4716558 100644
--- a/src/model.py
+++ b/src/model.py
@@ -521,7 +521,9 @@ class RWKV(pl.LightningModule):
        sample_num_times = max(1, seq_len - prompt.shape[-1])

        for _ in tqdm(range(sample_num_times), desc="gen responses"):
-            logits, embeds = self.forward(out, ppo_train=True)
+            pad_idx = torch.tensor([[eos_token] * (self.args.n_embd - out.shape[-1])])
+            query_idx = torch.cat((out, pad_idx), dim=-1)
+            logits, embeds = self.forward(query_idx, ppo_train=True)
            logits, embeds = logits[:, -1], embeds[:, -1]

            if exists(filter_logits_fn):
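Note on the change: prompt padding moves out of load_prompt_data_4_ppo (which previously zero-padded every prompt to req_len at load time) and into the generation loop in model.py, which now right-pads the running sequence with eos_token up to a fixed query length before each forward pass. Below is a minimal sketch of that padding step, assuming eos_token is an int token id and out is a (1, cur_len) LongTensor; pad_query is a hypothetical helper, and the fixed query length follows the diff in coming from args (the diff uses self.args.n_embd). Unlike the inline torch.tensor call in the diff, the sketch also carries over out's dtype and device, which the one-liner leaves at the CPU defaults.

    import torch

    def pad_query(out: torch.Tensor, eos_token: int, fixed_len: int) -> torch.Tensor:
        """Right-pad a (1, cur_len) token tensor with eos_token up to fixed_len."""
        pad_width = fixed_len - out.shape[-1]
        if pad_width <= 0:
            return out  # already at the fixed length, nothing to pad
        # Match the dtype/device of the running sequence so torch.cat works on GPU tensors too.
        pad_idx = torch.full((1, pad_width), eos_token, dtype=out.dtype, device=out.device)
        return torch.cat((out, pad_idx), dim=-1)

    # Usage mirroring the diff's generation loop:
    #   query_idx = pad_query(out, eos_token, self.args.n_embd)
    #   logits, embeds = self.forward(query_idx, ppo_train=True)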