From 3859a7137aee699025e1df9a6329b0a62976507c Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 10 Mar 2023 15:32:47 +0800 Subject: [PATCH] opt reward model --- src/rlhf/reward.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 665188b..eeafe74 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -119,11 +119,10 @@ class RewardModel(pl.LightningModule): state=None, extra_embed=extra_embed, rm_train=True - ) + )[:, -1, :] # 所有的 token 向量求平均,并输入到打分模块进行打分 - pooled = masked_mean(last_token_embeds, mask, dim = 1) - reward = self.pred_reward(pooled) + reward = self.pred_reward(last_token_embeds) return reward -- GitLab