diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index 7e67ab94ca8f1b3aa7258113b70bca3206098f88..3887da240a3f8c973eed0d2814f893ee49df06ae 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -66,14 +66,9 @@ class RewardModel(pl.LightningModule):
         dim = self.args.n_embd

         # markers used to tell prompt tokens from response tokens in the input; trained as model parameters, initialized to all zeros
-        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
+        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim)).to()
         self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
         self.padding_embed = nn.Parameter(torch.zeros(1, 1, dim), requires_grad=False)
-        self.prompt_response_mask_embed = torch.stack([
-            self.prompt_embed,
-            self.response_embed,
-            self.padding_embed
-        ])

         # reward score computation
         self.pred_reward = nn.Sequential(
@@ -158,15 +153,20 @@ class RewardModel(pl.LightningModule):
         # derive prompt mask from prompt lengths if exists
         if exists(prompt_lengths):
             batch, seq_len = x.shape
-            arange = torch.arange(seq_len, device = x.device)
+            arange = torch.arange(seq_len, device=x.device)
             prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')

         # reward model should have an understanding of which section is prompt, and which section is response
         # according to each token's True/False value in prompt_mask, take the embedding from prompt_embed or response_embed
         # if True, select from prompt_embed; otherwise select from response_embed
+        prompt_response_mask_embed = torch.stack([
+            self.prompt_embed,
+            self.response_embed,
+            self.padding_embed
+        ]).to(prompt_mask.device)
         extra_embed = None
         if exists(prompt_mask):
-            extra_embed = self.prompt_response_mask_embed[prompt_mask]
+            extra_embed = prompt_response_mask_embed[prompt_mask]

         # get the embedding of the last token
         last_token_embeds = self.rwkv(
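Why the stacking moved out of `__init__`: `torch.stack` over `nn.Parameter`s returns a plain tensor, so a table built at construction time is not registered with the module (it does not follow `model.to(device)` or Lightning's device placement), and it only holds a copy of the parameters' initial values, so it never reflects training updates. Below is a minimal sketch of the pattern the second hunk adopts: build the lookup table at call time and move it to the device of the tensor that indexes it. The `SegmentEmbed` module, the `segment_ids` argument, and the flattened `(3, dim)` table shape are illustrative simplifications, not the repo's `RewardModel` API.

```python
# Minimal sketch, assuming a hypothetical SegmentEmbed module (names and shapes
# are illustrative; they are not taken from src/rlhf/reward.py).
import torch
import torch.nn as nn


class SegmentEmbed(nn.Module):
    """Adds a trainable marker embedding per token: prompt, response, or padding."""

    def __init__(self, dim: int):
        super().__init__()
        # Trainable markers, initialized to zero; the padding marker stays frozen.
        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
        self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
        self.padding_embed = nn.Parameter(torch.zeros(1, 1, dim), requires_grad=False)

    def forward(self, segment_ids: torch.Tensor) -> torch.Tensor:
        # segment_ids: (batch, seq_len) long tensor, 0 = prompt, 1 = response, 2 = padding.
        # Building the table here (not in __init__) keeps it tied to the current
        # parameter values and lets us place it on the caller's device explicitly.
        table = torch.stack([
            self.prompt_embed.reshape(-1),
            self.response_embed.reshape(-1),
            self.padding_embed.reshape(-1),
        ]).to(segment_ids.device)   # (3, dim)
        return table[segment_ids]   # (batch, seq_len, dim)


# Usage: parameters move with .to(); the table follows them at call time.
if __name__ == "__main__":
    model = SegmentEmbed(dim=8)
    ids = torch.randint(0, 3, (2, 5))
    print(model(ids).shape)  # torch.Size([2, 5, 8])
```

Registering the stacked table with `register_buffer` would fix the device move but not the staleness, since a buffer would still be a frozen copy of the initial values; rebuilding the table inside `forward`, as the diff does, addresses both.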