提交 3859a713 编写于 作者: U u010280923

opt reward model

上级 ddbbe006
......@@ -119,11 +119,10 @@ class RewardModel(pl.LightningModule):
state=None,
extra_embed=extra_embed,
rm_train=True
)
)[:, -1, :]
# 所有的 token 向量求平均,并输入到打分模块进行打分
pooled = masked_mean(last_token_embeds, mask, dim = 1)
reward = self.pred_reward(pooled)
reward = self.pred_reward(last_token_embeds)
return reward
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册