From 749886a8cfc57f8b3580607c46f28fcc14ee0577 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Mon, 20 Mar 2023 16:04:59 +0800
Subject: [PATCH] opt ppo model

---
 src/rlhf/ppo.py | 71 ++-----------------------------------------------------------------
 1 file changed, 2 insertions(+), 69 deletions(-)

diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index b5d5fe3..de40a8f 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -42,7 +42,7 @@ PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
 ])
 
 @beartype
-class ActorCritic(pl.LightningModule):
+class ActorCritic(nn.Module):
     def __init__(
         self,
         rwkv: RWKV,
@@ -51,7 +51,7 @@ class ActorCritic(pl.LightningModule):
         pooled_values = False
     ):
         super().__init__()
-        self.actor = rwkv
+        self.actor = copy.deepcopy(rwkv)
 
         self.critic = critic
 
@@ -67,61 +67,6 @@ class ActorCritic(pl.LightningModule):
         nn.init.zeros_(self.value_head[0].bias)
         nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
 
-    def configure_optimizers(self):
-        args = self.args
-        if args.layerwise_lr > 0:
-            lr_1x = set()
-            lr_2x = set()
-            lr_3x = set()
-            for n, p in self.named_parameters():
-                if "time_mix" in n:
-                    if args.my_pile_stage == 2:
-                        lr_2x.add(n)
-                    else:
-                        lr_1x.add(n)
-                elif "time_decay" in n:
-                    if args.my_pile_stage == 2:
-                        lr_3x.add(n)
-                    else:
-                        lr_2x.add(n)
-                elif "time_first" in n:
-                    lr_3x.add(n)
-                else:
-                    lr_1x.add(n)
-            lr_1x = sorted(list(lr_1x))
-            lr_2x = sorted(list(lr_2x))
-            lr_3x = sorted(list(lr_3x))
-            param_dict = {n: p for n, p in self.named_parameters()}
-            if args.my_pile_stage == 2:
-                optim_groups = [
-                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
-                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
-                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
-                ]
-            else:
-                optim_groups = [
-                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
-                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
-                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
-                ]
-        else:
-            optim_groups = [
-                {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
-            ]
-
-        if self.deepspeed_offload:
-            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
-        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
-        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
-
-    @property
-    def deepspeed_offload(self) -> bool:
-        strategy = self.trainer.strategy
-        if isinstance(strategy, DeepSpeedStrategy):
-            cfg = strategy.config["zero_optimization"]
-            return cfg.get("offload_optimizer") or cfg.get("offload_param")
-        return False
-
     @torch.no_grad()
     @eval_decorator
     def generate(
@@ -204,18 +149,6 @@ class ActorCritic(pl.LightningModule):
 
         return action_logits, values
 
-# data
-
-Memory = namedtuple('Memory', [
-    'sequence',
-    'prompt_mask',
-    'mask',
-    'action_prob',
-    'action_log_prob',
-    'reward',
-    'value'
-])
-
 @beartype
 class ExperienceDataset(Dataset):
     def __init__(
-- 
GitLab
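
The patch strips the Lightning plumbing out of `ActorCritic` (it now subclasses `nn.Module`, so the `configure_optimizers` hook and the DeepSpeed offload check leave the class) and deep-copies the RWKV model so the trainable actor is decoupled from the original weights. Below is a minimal sketch of that pattern, not code from the repository: `nn.Linear` modules stand in for the RWKV policy and the critic, and a plain `torch.optim.Adam` stands in for the removed FusedAdam/DeepSpeedCPUAdam setup, which presumably now lives in the surrounding trainer.

```python
import copy

import torch
import torch.nn as nn


class ActorCritic(nn.Module):
    """Plain nn.Module version, mirroring the shape of the patched class."""

    def __init__(self, base: nn.Module, critic: nn.Module):
        super().__init__()
        # Deep copy so PPO updates to the actor leave the reference model intact.
        self.actor = copy.deepcopy(base)
        self.critic = critic


base = nn.Linear(8, 8)    # hypothetical stand-in for the RWKV policy network
critic = nn.Linear(8, 1)  # hypothetical stand-in for the critic / value network
model = ActorCritic(base, critic)

# With LightningModule gone, the optimizer is built outside the module,
# e.g. in the PPO training loop, instead of in configure_optimizers().
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# The deep copy means the trainable actor and the original model are distinct
# objects, so optimizing the actor never mutates `base`.
assert model.actor is not base
```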