From 5b3f373ae1b11b2e7b971ef598f38cd86caa607f Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 10 Mar 2023 18:54:00 +0800 Subject: [PATCH] bug fixed --- src/rlhf/reward.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 52307b7..12b2e77 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -129,7 +129,8 @@ class RewardModel(pl.LightningModule): # if self.deepspeed_offload: # return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) # return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) - return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) + # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) + return torch.optim.Adam(optim_groups, lr=1e-5, betas=(0.9, 0.95)) @property def deepspeed_offload(self) -> bool: -- GitLab