From a1fe37552df33b4a446270d89f481efd39d1d825 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Fri, 10 Mar 2023 14:58:37 +0800
Subject: [PATCH] opt reward model

---
 src/dataset.py     | 49 ++++++++++++++++++++++++++++++++++------------
 src/rlhf/reward.py | 11 ++++++-----
 2 files changed, 43 insertions(+), 17 deletions(-)

diff --git a/src/dataset.py b/src/dataset.py
index 68b6807..001424c 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -267,6 +267,9 @@ class RMDataset(Dataset):
             "20B_tokenizer.json",
             "20B_tokenizer.json",
         ]  # [vocab, vocab] for Pile model
+        self.prompt_mask_id = 0
+        self.response_mask_id = 1
+        self.padding_mask_id = 2

         self.tokenizer = TOKENIZER(WORD_NAME)
         pf = pd.read_csv(args.data_file)
@@ -275,10 +278,26 @@
             prompt = row["prompt"]
             preferred = row["preferred"]
             alternate = row["alternate"]
-            preferred_sample = f"{prompt}\n{preferred}"
-            alternate_sample = f"{prompt}\n{alternate}"
-            data_list.append((self.tokenizer.tokenizer.encode(preferred_sample),
-                              self.tokenizer.tokenizer.encode(alternate_sample)))
+
+            prompt_idx = self.tokenizer.tokenizer.encode(prompt)
+            preferred_idx = self.tokenizer.tokenizer.encode(preferred)
+            alternate_idx = self.tokenizer.tokenizer.encode(alternate)
+
+            prompt_mask = [self.padding_mask_id] * len(prompt_idx)
+            preferred_mask = [self.response_mask_id] * len(preferred_idx)
+            alternate_mask = [self.response_mask_id] * len(alternate_idx)
+
+            prompt_prefer_idx = prompt_idx + preferred_idx
+            prompt_alter_idx = prompt_idx + alternate_idx
+
+            prompt_prefer_mask = prompt_mask + preferred_mask
+            prompt_alter_mask = prompt_mask + alternate_mask
+
+            data_list.append((
+                prompt_prefer_idx, prompt_alter_idx,
+                prompt_prefer_mask, prompt_alter_mask
+            ))
+
         self.data = data_list

     def __len__(self):
@@ -287,15 +306,21 @@
     def __getitem__(self, index):
         ctx_len = self.args.ctx_len
         req_len = ctx_len + 1
-        preferred_sample, alternate_sample = self.data[index]
+        prompt_prefer_idx, prompt_alter_idx, prompt_prefer_mask, prompt_alter_mask = self.data[index]

-        preferred_sample = preferred_sample[: req_len]
-        alternate_sample = alternate_sample[: req_len]
+        prompt_prefer_idx = prompt_prefer_idx[: req_len]
+        prompt_alter_idx = prompt_alter_idx[: req_len]
+        prompt_prefer_mask = prompt_prefer_mask[: req_len]
+        prompt_alter_mask = prompt_alter_mask[: req_len]

-        preferred_sample = preferred_sample + [0] * (req_len - len(preferred_sample))
-        alternate_sample = alternate_sample + [0] * (req_len - len(alternate_sample))
+        prompt_prefer_idx = prompt_prefer_idx + [1] * (req_len - len(prompt_prefer_idx))
+        prompt_alter_idx = prompt_alter_idx + [1] * (req_len - len(prompt_alter_idx))
+        prompt_prefer_mask = prompt_prefer_mask + [self.padding_mask_id] * (req_len - len(prompt_prefer_mask))
+        prompt_alter_mask = prompt_alter_mask + [self.padding_mask_id] * (req_len - len(prompt_alter_mask))

-        x_p = torch.tensor(preferred_sample, dtype=torch.long)
-        x_a = torch.tensor(alternate_sample, dtype=torch.long)
+        x_p = torch.tensor(prompt_prefer_idx, dtype=torch.long)
+        x_a = torch.tensor(prompt_alter_idx, dtype=torch.long)
+        m_p = torch.tensor(prompt_prefer_mask, dtype=torch.long)
+        m_a = torch.tensor(prompt_alter_mask, dtype=torch.long)

-        return x_p, x_a
\ No newline at end of file
+        return x_p, x_a, m_p, m_a
\ No newline at end of file
diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index 559860f..8e17cea 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -65,6 +65,7 @@ class RewardModel(pl.LightningModule):
         # Used to distinguish prompt from response in the input; trained as model parameters, initialized to all zeros
         self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
         self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
+        self.padding_embed = nn.Parameter(torch.zeros(1, 1, dim), requires_grad=False)

         # reward score computation
         self.pred_reward = nn.Sequential(
@@ -135,16 +136,16 @@ class RewardModel(pl.LightningModule):

         return reward

-    def forward(self, prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask):
-        prefer_reward = self.single_forward(prefer_x, prefer_x_prompt_mask)
-        alter_reward = self.single_forward(alter_x, alter_x_prompt_mask)
+    def forward(self, x_p, x_a, m_p, m_a):
+        prefer_reward = self.single_forward(x_p, prompt_mask=m_p)
+        alter_reward = self.single_forward(x_a, prompt_mask=m_a)

         return prefer_reward, alter_reward

     def training_step(self, batch, batch_idx):
-        prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask = batch
+        x_p, x_a, m_p, m_a = batch

         prefer_reward, alter_reward = self(
-            prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask)
+            x_p, x_a, m_p, m_a)

         loss = loss_function(prefer_reward, alter_reward)
-- 
GitLab
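
Note on the loss: loss_function is called in training_step but is not defined in this patch. Reward models of this kind are usually trained with a pairwise Bradley-Terry-style preference loss, where the preferred response must score higher than the alternate. A minimal sketch under that assumption (the actual implementation in src/rlhf/reward.py may differ):

    import torch
    import torch.nn.functional as F

    def loss_function(prefer_reward: torch.Tensor, alter_reward: torch.Tensor) -> torch.Tensor:
        # Pairwise preference loss: -log sigmoid(r_preferred - r_alternate),
        # averaged over the batch. logsigmoid keeps the computation
        # numerically stable when the reward gap is large.
        return -F.logsigmoid(prefer_reward - alter_reward).mean()

With this form, the loss shrinks toward zero as prefer_reward exceeds alter_reward by a growing margin, which is what the paired (x_p, m_p) / (x_a, m_a) batches built in RMDataset are intended to drive.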