From dfeee74605ebd68b54da9474ff247f2874ec6995 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 10 Mar 2023 14:15:41 +0800 Subject: [PATCH] opt reward model --- src/rlhf/reward.py | 2 +- train_rm.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 1b1ba9d..559860f 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -68,7 +68,7 @@ class RewardModel(pl.LightningModule): # reward 得分计算 self.pred_reward = nn.Sequential( - nn.Linear(dim, 1), + nn.Linear(dim, 1, bias=False), Rearrange('... 1 -> ...') # 降维 ) diff --git a/train_rm.py b/train_rm.py index 94d2065..b0ef71a 100644 --- a/train_rm.py +++ b/train_rm.py @@ -240,7 +240,10 @@ if __name__ == "__main__": rm_model = RewardModel(args) # 训练 - trainer = Trainer.from_argparse_args() + trainer = Trainer.from_argparse_args( + args, + callbacks=[train_callback(args)], + ) if trainer.global_rank == 0: for n in rm_model.state_dict(): -- GitLab