diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index 2a68d9e6525790fe574b82d4db2ae682635c3a8d..7e67ab94ca8f1b3aa7258113b70bca3206098f88 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -69,6 +69,11 @@ class RewardModel(pl.LightningModule):
         self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
         self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
         self.padding_embed = nn.Parameter(torch.zeros(1, 1, dim), requires_grad=False)
+        self.prompt_response_mask_embed = torch.stack([
+            self.prompt_embed,
+            self.response_embed,
+            self.padding_embed
+        ])
 
         # reward score computation
         self.pred_reward = nn.Sequential(
@@ -161,11 +166,7 @@ class RewardModel(pl.LightningModule):
         # if True, pick from prompt_embed; otherwise pick from response_embed
         extra_embed = None
         if exists(prompt_mask):
-            extra_embed = torch.where(
-                rearrange(prompt_mask, 'b n -> b n 1'),
-                self.prompt_embed,
-                self.response_embed
-            )
+            extra_embed = self.prompt_response_mask_embed[prompt_mask]
 
         # get the embedding of the last token
         last_token_embeds = self.rwkv(
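
The change swaps a two-way `torch.where` select between `prompt_embed` and `response_embed` for a single fancy-indexed lookup into a stacked embedding table, which also makes room for a third, frozen `padding_embed`. Below is a minimal, self-contained sketch of that lookup pattern, assuming `prompt_mask` now carries integer indices in {0, 1, 2} per token (prompt / response / padding); the standalone module and dimensions are illustrative, not the repository's API:

```python
# A sketch of the index-lookup pattern this diff introduces. The names
# prompt_embed / response_embed / padding_embed / prompt_mask mirror the diff;
# the ExtraEmbed module and dim are hypothetical, for illustration only.
import torch
import torch.nn as nn


class ExtraEmbed(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.prompt_embed = nn.Parameter(torch.zeros(dim))
        self.response_embed = nn.Parameter(torch.zeros(dim))
        self.padding_embed = nn.Parameter(torch.zeros(dim), requires_grad=False)

    def forward(self, prompt_mask: torch.Tensor) -> torch.Tensor:
        # Stack inside forward so the table always reflects the current
        # parameter values and stays on this step's autograd graph.
        # Shape (3, dim); row 0 = prompt, 1 = response, 2 = padding.
        table = torch.stack([
            self.prompt_embed,
            self.response_embed,
            self.padding_embed,
        ])
        # prompt_mask holds integer indices in {0, 1, 2} per token;
        # fancy indexing broadcasts the lookup to (batch, seq_len, dim).
        return table[prompt_mask]


if __name__ == "__main__":
    embed = ExtraEmbed(dim=4)
    mask = torch.tensor([[0, 0, 1, 1, 2]])  # prompt, prompt, response, response, pad
    print(embed(mask).shape)  # torch.Size([1, 5, 4])
```

Two things seem worth double-checking in the diff as written. First, the table is stacked once in `__init__`, so it is a snapshot of the parameters at construction time rather than a live view; the sketch stacks per forward pass so the lookup tracks training updates. Second, since the parameters there have shape `(1, 1, dim)`, `self.prompt_response_mask_embed[prompt_mask]` would come out as `(b, n, 1, 1, dim)` unless squeezed, whereas the old `torch.where` produced `(b, n, dim)`. Note also that if `prompt_mask` remained boolean, `True` would index row 1 (`response_embed`), inverting the old `torch.where` semantics, so the new mask presumably encodes the indices explicitly.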