diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 1b1ba9dc3125c180a8cab00c13ad606f96e91011..559860f4cfd87f2fb34d63abb93c2bfc8be3cd63 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -68,7 +68,7 @@ class RewardModel(pl.LightningModule): # reward 得分计算 self.pred_reward = nn.Sequential( - nn.Linear(dim, 1), + nn.Linear(dim, 1, bias=False), Rearrange('... 1 -> ...') # 降维 ) diff --git a/train_rm.py b/train_rm.py index 94d206540612c3d1e5cebc44c095fc095eaa680e..b0ef71aa64c5a5baac30265cf5b09bb1d9819e0a 100644 --- a/train_rm.py +++ b/train_rm.py @@ -240,7 +240,10 @@ if __name__ == "__main__": rm_model = RewardModel(args) # 训练 - trainer = Trainer.from_argparse_args() + trainer = Trainer.from_argparse_args( + args, + callbacks=[train_callback(args)], + ) if trainer.global_rank == 0: for n in rm_model.state_dict():