Commit a1fe3755 authored by u010280923

opt reward model

Parent dfeee746
@@ -267,6 +267,9 @@ class RMDataset(Dataset):
             "20B_tokenizer.json",
             "20B_tokenizer.json",
         ] # [vocab, vocab] for Pile model
+        self.prompt_mask_id = 0
+        self.response_mask_id = 1
+        self.padding_mask_id = 2
         self.tokenizer = TOKENIZER(WORD_NAME)
         pf = pd.read_csv(args.data_file)
@@ -275,10 +278,26 @@ class RMDataset(Dataset):
             prompt = row["prompt"]
             preferred = row["preferred"]
             alternate = row["alternate"]
-            preferred_sample = f"{prompt}\n{preferred}"
-            alternate_sample = f"{prompt}\n{alternate}"
-            data_list.append((self.tokenizer.tokenizer.encode(preferred_sample),
-                              self.tokenizer.tokenizer.encode(alternate_sample)))
+            prompt_idx = self.tokenizer.tokenizer.encode(prompt)
+            preferred_idx = self.tokenizer.tokenizer.encode(preferred)
+            alternate_idx = self.tokenizer.tokenizer.encode(alternate)
+            prompt_mask = [self.padding_mask_id] * len(prompt_idx)
+            preferred_mask = [self.response_mask_id] * len(preferred_idx)
+            alternate_mask = [self.response_mask_id] * len(alternate_idx)
+            prompt_prefer_idx = prompt_idx + preferred_idx
+            prompt_alter_idx = prompt_idx + alternate_idx
+            prompt_prefer_mask = prompt_mask + preferred_mask
+            prompt_alter_mask = prompt_mask + alternate_mask
+            data_list.append((
+                prompt_prefer_idx, prompt_alter_idx,
+                prompt_prefer_mask, prompt_alter_mask
+            ))
         self.data = data_list

     def __len__(self):
@@ -287,15 +306,21 @@ class RMDataset(Dataset):
     def __getitem__(self, index):
         ctx_len = self.args.ctx_len
         req_len = ctx_len + 1
-        preferred_sample, alternate_sample = self.data[index]
+        prompt_prefer_idx, prompt_alter_idx, prompt_prefer_mask, prompt_alter_mask = self.data[index]
-        preferred_sample = preferred_sample[: req_len]
-        alternate_sample = alternate_sample[: req_len]
+        prompt_prefer_idx = prompt_prefer_idx[: req_len]
+        prompt_alter_idx = prompt_alter_idx[: req_len]
+        prompt_prefer_mask = prompt_prefer_mask[: req_len]
+        prompt_alter_mask = prompt_alter_mask[: req_len]
-        preferred_sample = preferred_sample + [0] * (req_len - len(preferred_sample))
-        alternate_sample = alternate_sample + [0] * (req_len - len(alternate_sample))
+        prompt_prefer_idx = prompt_prefer_idx + [1] * (req_len - len(prompt_prefer_idx))
+        prompt_alter_idx = prompt_alter_idx + [1] * (req_len - len(prompt_alter_idx))
+        prompt_prefer_mask = prompt_prefer_mask + [self.padding_mask_id] * (req_len - len(prompt_prefer_mask))
+        prompt_alter_mask = prompt_alter_mask + [self.padding_mask_id] * (req_len - len(prompt_alter_mask))
-        x_p = torch.tensor(preferred_sample, dtype=torch.long)
-        x_a = torch.tensor(alternate_sample, dtype=torch.long)
+        x_p = torch.tensor(prompt_prefer_idx, dtype=torch.long)
+        x_a = torch.tensor(prompt_alter_idx, dtype=torch.long)
+        m_p = torch.tensor(prompt_prefer_mask, dtype=torch.long)
+        m_a = torch.tensor(prompt_alter_mask, dtype=torch.long)
-        return x_p, x_a
\ No newline at end of file
+        return x_p, x_a, m_p, m_a
\ No newline at end of file
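
For orientation, here is a minimal sketch of how the reworked dataset might be consumed. RMDataset is real, but the args fields, batch size, and the reward_model instance are assumptions for illustration, not part of this commit.

from torch.utils.data import DataLoader

dataset = RMDataset(args)  # assumes args.data_file and args.ctx_len are set
loader = DataLoader(dataset, batch_size=8, shuffle=True)

for x_p, x_a, m_p, m_a in loader:
    # each tensor has shape [batch_size, ctx_len + 1]
    # x_*: token ids (padded with 1); m_*: mask ids (padded with padding_mask_id = 2)
    prefer_reward, alter_reward = reward_model(x_p, x_a, m_p, m_a)  # reward_model: a RewardModel instance (assumed)
    break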
@@ -65,6 +65,7 @@ class RewardModel(pl.LightningModule):
         # Distinguishes prompt tokens from response tokens in the input; trained as model parameters, initialized to all zeros
         self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
         self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
+        self.padding_embed = nn.Parameter(torch.zeros(1, 1, dim), requires_grad=False)
         # reward score computation
         self.pred_reward = nn.Sequential(
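
single_forward is not part of this diff, so how the new padding_embed is consumed is not shown. A plausible reading, sketched below purely as an assumption, is that each position's mask id (0 = prompt, 1 = response, 2 = padding) selects which of the three vectors is added to that token's embedding.

import torch

# Hypothetical helper, not the repository's actual single_forward logic:
# gather one of the three segment vectors per position according to the mask id.
def add_segment_embed(token_embed, prompt_mask, prompt_embed, response_embed, padding_embed):
    # token_embed: [batch, seq, dim]; prompt_mask: [batch, seq] holding ids 0 / 1 / 2
    segment = torch.cat([prompt_embed, response_embed, padding_embed], dim=1)  # [1, 3, dim]
    return token_embed + segment[0][prompt_mask]  # index [3, dim] by ids -> [batch, seq, dim]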
@@ -135,16 +136,16 @@ class RewardModel(pl.LightningModule):
         return reward

-    def forward(self, prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask):
-        prefer_reward = self.single_forward(prefer_x, prefer_x_prompt_mask)
-        alter_reward = self.single_forward(alter_x, alter_x_prompt_mask)
+    def forward(self, x_p, x_a, m_p, m_a):
+        prefer_reward = self.single_forward(x_p, prompt_mask=m_p)
+        alter_reward = self.single_forward(x_a, prompt_mask=m_a)
         return prefer_reward, alter_reward

     def training_step(self, batch, batch_idx):
-        prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask = batch
+        x_p, x_a, m_p, m_a = batch
         prefer_reward, alter_reward = self(
-            prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask)
+            x_p, x_a, m_p, m_a)
         loss = loss_function(prefer_reward, alter_reward)
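
loss_function is defined outside this diff. Reward models trained on preference pairs commonly use a pairwise log-sigmoid (Bradley-Terry style) objective; the sketch below is an assumption about what it computes, not the repository's actual code.

import torch.nn.functional as F

# Hypothetical pairwise preference loss: push the preferred sample's reward
# above the alternate sample's reward.
def loss_function(prefer_reward, alter_reward):
    return -F.logsigmoid(prefer_reward - alter_reward).mean()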