From 609427511a2deb4943fd5714afa9bb2edb2fcaa1 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Fri, 10 Mar 2023 16:31:21 +0800
Subject: [PATCH] opt reward model

---
 src/rlhf/reward.py | 62 +++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 58 insertions(+), 4 deletions(-)

diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index eeafe74..32835fb 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -10,6 +10,9 @@ from torch import nn
 import torch.nn.functional as F
 import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_info
+import deepspeed
+from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+from pytorch_lightning.strategies import DeepSpeedStrategy
 
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
@@ -79,11 +82,62 @@ class RewardModel(pl.LightningModule):
         self.load_state_dict(torch.load(str(path)))
 
     def configure_optimizers(self):
-        # 论文中的参数:
-        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5, betas=(0.9, 0.95) )
-        # optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr_init, betas=self.args.betas)
+        args = self.args
+        if args.layerwise_lr > 0:
+            lr_1x = set()
+            lr_2x = set()
+            lr_3x = set()
+            for n, p in self.named_parameters():
+                if "time_mix" in n:
+                    if args.my_pile_stage == 2:
+                        lr_2x.add(n)
+                    else:
+                        lr_1x.add(n)
+                elif "time_decay" in n:
+                    if args.my_pile_stage == 2:
+                        lr_3x.add(n)
+                    else:
+                        lr_2x.add(n)
+                elif "time_first" in n:
+                    lr_3x.add(n)
+                else:
+                    lr_1x.add(n)
+            lr_1x = sorted(list(lr_1x))
+            lr_2x = sorted(list(lr_2x))
+            lr_3x = sorted(list(lr_3x))
+            # print('1x', lr_1x)
+            # print('2x', lr_2x)
+            # print('3x', lr_3x)
+            param_dict = {n: p for n, p in self.named_parameters()}
+            if args.my_pile_stage == 2:
+                optim_groups = [
+                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
+                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
+                ]
+            else:
+                optim_groups = [
+                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
+                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
+                ]
+        else:
+            optim_groups = [
+                {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
+            ]
 
-        return optimizer
+        if self.deepspeed_offload:
+            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
+        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
+
+    @property
+    def deepspeed_offload(self) -> bool:
+        strategy = self.trainer.strategy
+        if isinstance(strategy, DeepSpeedStrategy):
+            cfg = strategy.config["zero_optimization"]
+            return cfg.get("offload_optimizer") or cfg.get("offload_param")
+        return False
 
     def single_forward(
         self,
-- 
GitLab
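
Not part of the patch itself: a minimal, self-contained sketch of the layerwise learning-rate grouping that the new configure_optimizers performs. ToyBlock and group_by_lr_scale are hypothetical names used only for illustration; in the patch the same name-matching rules ("time_mix" / "time_decay" / "time_first") run inside RewardModel.configure_optimizers and the resulting groups are handed to DeepSpeedCPUAdam or FusedAdam.

# Illustrative sketch, not the patched code. Shows only the args.my_pile_stage != 2
# branch (scales 1.0 / 2.0 / 3.0); the stage-2 branch in the patch uses 1.0 / 5.0 / 5.0.
import torch
from torch import nn

class ToyBlock(nn.Module):
    # Stand-in module whose parameter names follow the RWKV-style conventions
    # that the grouping keys on.
    def __init__(self, dim=8):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(dim))
        self.time_decay = nn.Parameter(torch.zeros(dim))
        self.time_first = nn.Parameter(torch.zeros(dim))
        self.proj = nn.Linear(dim, dim)

def group_by_lr_scale(model, my_pile_stage=1):
    # Bucket parameter names exactly as the patch does.
    lr_1x, lr_2x, lr_3x = set(), set(), set()
    for n, _ in model.named_parameters():
        if "time_mix" in n:
            (lr_2x if my_pile_stage == 2 else lr_1x).add(n)
        elif "time_decay" in n:
            (lr_3x if my_pile_stage == 2 else lr_2x).add(n)
        elif "time_first" in n:
            lr_3x.add(n)
        else:
            lr_1x.add(n)
    param_dict = dict(model.named_parameters())
    # Each group carries a "my_lr_scale" tag alongside the usual optimizer fields.
    return [
        {"params": [param_dict[n] for n in sorted(lr_1x)], "weight_decay": 0.0, "my_lr_scale": 1.0},
        {"params": [param_dict[n] for n in sorted(lr_2x)], "weight_decay": 0.0, "my_lr_scale": 2.0},
        {"params": [param_dict[n] for n in sorted(lr_3x)], "weight_decay": 0.0, "my_lr_scale": 3.0},
    ]

groups = group_by_lr_scale(ToyBlock())
print([len(g["params"]) for g in groups])  # [3, 1, 1]: proj.* + time_mix_k / time_decay / time_first

Note that the patch only attaches the my_lr_scale tag; nothing in this diff consumes it, so presumably a training-loop callback elsewhere multiplies each group's base lr_init by that factor. Whether DeepSpeedCPUAdam or FusedAdam is returned depends solely on the deepspeed_offload property, which reads the DeepSpeedStrategy's zero_optimization config for offload_optimizer / offload_param.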