From 51230ba006618279f792ec2ead192c76ca9a6e57 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Tue, 21 Mar 2023 00:11:58 +0800
Subject: [PATCH] opt ppo model

---
 src/dataset.py | 2 --
 src/model.py   | 4 +++-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/dataset.py b/src/dataset.py
index b3fab1b..399635d 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -364,8 +364,6 @@ def load_prompt_data_4_ppo(args):
         prompt = row["prompt"]
         prompt_idx = tokenizer.tokenizer.encode(prompt)
         prompt_idx = prompt_idx[: req_len]
-        prompt_idx = prompt_idx + [0] * (req_len - len(prompt_idx))
-
         prompt_token_ids.append(
             torch.tensor(prompt_idx, dtype=torch.long))
 
diff --git a/src/model.py b/src/model.py
index ca8559b..8e8bdfd 100644
--- a/src/model.py
+++ b/src/model.py
@@ -521,7 +521,9 @@ class RWKV(pl.LightningModule):
         sample_num_times = max(1, seq_len - prompt.shape[-1])
 
         for _ in tqdm(range(sample_num_times), desc="gen responses"):
-            logits, embeds = self.forward(out, ppo_train=True)
+            pad_idx = torch.tensor([[eos_token] * (self.args.n_embd - out.shape[-1])])
+            query_idx = torch.cat((out, pad_idx), dim=-1)
+            logits, embeds = self.forward(query_idx, ppo_train=True)
             logits, embeds = logits[:, -1], embeds[:, -1]
 
             if exists(filter_logits_fn):
-- 
GitLab