diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 8e17cea7537567f911619378f23b26e41881db23..df7724448c03794cff15cd049850639bd667c881 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -85,12 +85,9 @@ class RewardModel(pl.LightningModule): ] def configure_optimizers(self): - # 论文中的参数:lr=1e-5, betas=(0.9, 0.95) - optimizer = torch.optim.Adam([ - {"rwkv_params": self.rwkv.parameters()}, - {"rm_params": self.parameters()} - ], lr=self.args.lr_init, betas=self.args.betas) - + # 论文中的参数: + optimizer = torch.optim.Adam(self.parameters(), lr=1e-5, betas=(0.9, 0.95) ) + # optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr_init, betas=self.args.betas) return optimizer