From aaa0d0a218bd3367295b4f6b386f2a04524201f3 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Fri, 10 Mar 2023 18:45:15 +0800
Subject: [PATCH] bug fixed

---
 src/rlhf/reward.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index 2318dd3..52307b7 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -12,6 +12,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_info
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
 from pytorch_lightning.strategies import DeepSpeedStrategy
 
 from einops import rearrange, repeat, reduce, pack, unpack
@@ -125,10 +126,10 @@ class RewardModel(pl.LightningModule):
             {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
         ]
 
-        if self.deepspeed_offload:
-            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
-        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
-        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
+        # if self.deepspeed_offload:
+        #     return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
+        # return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+        return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
 
     @property
     def deepspeed_offload(self) -> bool:
@@ -174,7 +175,7 @@ class RewardModel(pl.LightningModule):
             rm_train=True
         )[:, -1, :]
 
-        # 所有的 token 向量求平均,并输入到打分模块进行打分
+        # 计算奖励
         reward = self.pred_reward(last_token_embeds)
 
         return reward
-- 
GitLab