From f7516c41912be40f24cb72ed9b6498b5130ef5fb Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 10 Mar 2023 22:48:53 +0800 Subject: [PATCH] opt reward model --- src/rlhf/reward.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 12b2e77..0b8f652 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -126,11 +126,10 @@ class RewardModel(pl.LightningModule): {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, ] - # if self.deepspeed_offload: - # return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) - # return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + if self.deepspeed_offload: + return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) + return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) - return torch.optim.Adam(optim_groups, lr=1e-5, betas=(0.9, 0.95)) @property def deepspeed_offload(self) -> bool: -- GitLab