From b5588704241bc550f708e4a330c9c67140daf8f8 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Mon, 20 Mar 2023 15:20:36 +0800
Subject: [PATCH] opt ppo model

---
 src/rlhf/ppo.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 56 insertions(+), 1 deletion(-)

diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index a244bb3..b5d5fe3 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -42,7 +42,7 @@ PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
 ])
 
 @beartype
-class ActorCritic(nn.Module):
+class ActorCritic(pl.LightningModule):
     def __init__(
         self,
         rwkv: RWKV,
@@ -67,6 +67,61 @@ class ActorCritic(nn.Module):
         nn.init.zeros_(self.value_head[0].bias)
         nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
 
+    def configure_optimizers(self):
+        args = self.args
+        if args.layerwise_lr > 0:
+            lr_1x = set()
+            lr_2x = set()
+            lr_3x = set()
+            for n, p in self.named_parameters():
+                if "time_mix" in n:
+                    if args.my_pile_stage == 2:
+                        lr_2x.add(n)
+                    else:
+                        lr_1x.add(n)
+                elif "time_decay" in n:
+                    if args.my_pile_stage == 2:
+                        lr_3x.add(n)
+                    else:
+                        lr_2x.add(n)
+                elif "time_first" in n:
+                    lr_3x.add(n)
+                else:
+                    lr_1x.add(n)
+            lr_1x = sorted(list(lr_1x))
+            lr_2x = sorted(list(lr_2x))
+            lr_3x = sorted(list(lr_3x))
+            param_dict = {n: p for n, p in self.named_parameters()}
+            if args.my_pile_stage == 2:
+                optim_groups = [
+                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 2e-3 / args.lr_init
+                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 3e-3 / args.lr_init
+                ]
+            else:
+                optim_groups = [
+                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
+                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
+                ]
+        else:
+            optim_groups = [
+                {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
+            ]
+
+        if self.deepspeed_offload:
+            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
+        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
+
+    @property
+    def deepspeed_offload(self) -> bool:
+        strategy = self.trainer.strategy
+        if isinstance(strategy, DeepSpeedStrategy):
+            cfg = strategy.config["zero_optimization"]
+            return cfg.get("offload_optimizer") or cfg.get("offload_param")
+        return False
+
     @torch.no_grad()
     @eval_decorator
     def generate(
--
GitLab
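
Review note (not part of the patch above): the new `configure_optimizers` hook buckets parameters into optimizer groups that carry a custom `my_lr_scale` key, which a learning-rate callback elsewhere in the training code is presumably expected to multiply into each group's rate. The sketch below is a minimal, self-contained illustration of that grouping technique using plain `torch.optim.Adam` and a toy module; the module, the base rate, and the loop that applies `my_lr_scale` are assumptions for illustration, not code from this repository.

# Minimal sketch of layer-wise LR scaling via custom keys on optimizer
# param groups. The Toy module and the scaling loop are hypothetical;
# only the "my_lr_scale" grouping idea mirrors the patch above.
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_mix = nn.Parameter(torch.zeros(4))
        self.time_decay = nn.Parameter(torch.zeros(4))
        self.time_first = nn.Parameter(torch.zeros(4))
        self.head = nn.Linear(4, 4)

model = Toy()
buckets = {1.0: [], 2.0: [], 3.0: []}
for name, param in model.named_parameters():
    if "time_first" in name:
        buckets[3.0].append(param)   # highest-LR bucket
    elif "time_decay" in name:
        buckets[2.0].append(param)
    else:
        buckets[1.0].append(param)   # time_mix and everything else

base_lr = 1e-4
optim_groups = [
    {"params": params, "weight_decay": 0.0, "my_lr_scale": scale}
    for scale, params in buckets.items() if params
]
# torch optimizers keep unknown group keys, so "my_lr_scale" rides along.
optimizer = torch.optim.Adam(optim_groups, lr=base_lr)

# A scheduler or callback can then apply the per-group scale each step:
for group in optimizer.param_groups:
    group["lr"] = base_lr * group.get("my_lr_scale", 1.0)

The `deepspeed_offload` property added alongside it inspects the Lightning trainer's DeepSpeed strategy config and selects `DeepSpeedCPUAdam` only when ZeRO optimizer or parameter offload is enabled, which is the mode that CPU-Adam kernel is intended for; otherwise it falls back to `FusedAdam`.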