diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 3887da240a3f8c973eed0d2814f893ee49df06ae..b9797d1a3e0d9b89ff01ebb0cae10eb206bb0de4 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -171,7 +171,6 @@ class RewardModel(pl.LightningModule): # 获得最后一个 token 的 embedding last_token_embeds = self.rwkv( x, - state=None, extra_embed=extra_embed, rm_train=True )[:, -1, :]