diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 12b2e77dd69632cd13f6491143892567a59a7d95..0b8f6520dadefd1341bc28aeacebed288c6bad59 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -126,11 +126,10 @@ class RewardModel(pl.LightningModule): {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, ] - # if self.deepspeed_offload: - # return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) - # return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + if self.deepspeed_offload: + return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) + return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) - return torch.optim.Adam(optim_groups, lr=1e-5, betas=(0.9, 0.95)) @property def deepspeed_offload(self) -> bool: