diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py index a244bb3357c191424bbbd24d36fe65f771fc4d06..b5d5fe33a9808c1bf73ce12fdf21455ead0a6e71 100644 --- a/src/rlhf/ppo.py +++ b/src/rlhf/ppo.py @@ -42,7 +42,7 @@ PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [ ]) @beartype -class ActorCritic(nn.Module): +class ActorCritic(pl.LightningModule): def __init__( self, rwkv: RWKV, @@ -67,6 +67,61 @@ class ActorCritic(nn.Module): nn.init.zeros_(self.value_head[0].bias) nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2)) + def configure_optimizers(self): + args = self.args + if args.layerwise_lr > 0: + lr_1x = set() + lr_2x = set() + lr_3x = set() + for n, p in self.named_parameters(): + if "time_mix" in n: + if args.my_pile_stage == 2: + lr_2x.add(n) + else: + lr_1x.add(n) + elif "time_decay" in n: + if args.my_pile_stage == 2: + lr_3x.add(n) + else: + lr_2x.add(n) + elif "time_first" in n: + lr_3x.add(n) + else: + lr_1x.add(n) + lr_1x = sorted(list(lr_1x)) + lr_2x = sorted(list(lr_2x)) + lr_3x = sorted(list(lr_3x)) + param_dict = {n: p for n, p in self.named_parameters()} + if args.my_pile_stage == 2: + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, + ] + else: + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, + ] + else: + optim_groups = [ + {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, + ] + + if self.deepspeed_offload: + return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) + return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) + + @property + def deepspeed_offload(self) -> bool: + strategy = self.trainer.strategy + if isinstance(strategy, DeepSpeedStrategy): + cfg = strategy.config["zero_optimization"] + return cfg.get("offload_optimizer") or cfg.get("offload_param") + return False + @torch.no_grad() @eval_decorator def generate(