From 5f46304e33ad9dfdeaf59be774690fb58bf617e4 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 10 Mar 2023 17:11:03 +0800 Subject: [PATCH] opt reward model --- src/rlhf/reward.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 2a68d9e..7e67ab9 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -69,6 +69,11 @@ class RewardModel(pl.LightningModule): self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim)) self.response_embed = nn.Parameter(torch.zeros(1, 1, dim)) self.padding_embed = nn.Parameter(torch.zeros(1, 1, dim), requires_grad=False) + self.prompt_response_mask_embed = torch.stack([ + self.prompt_embed, + self.response_embed, + self.padding_embed + ]) # reward 得分计算 self.pred_reward = nn.Sequential( @@ -161,11 +166,7 @@ class RewardModel(pl.LightningModule): # 如果为 True,则从 prompt_embed 中选,否则从 response_embed 中选 extra_embed = None if exists(prompt_mask): - extra_embed = torch.where( - rearrange(prompt_mask, 'b n -> b n 1'), - self.prompt_embed, - self.response_embed - ) + extra_embed = self.prompt_response_mask_embed[prompt_mask] # 获得最后一个 token 的 embedding last_token_embeds = self.rwkv( -- GitLab