From fa6038f4f92dcc13d6c1d43b9ca4508dcb5312d4 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 17 Mar 2023 17:38:00 +0800 Subject: [PATCH] update ppo model --- README.md | 4 +-- src/rlhf/ppo.py | 81 ++++++++++++++++++++++++++++++++++++++++--------- src/trainer.py | 2 +- train_ppo.py | 4 +++ 4 files changed, 73 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index a758549..244ecf6 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" ### Reward Model ``` -python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \ +python train_rm.py --load_model "./out_sft/rwkv-190.pth" --wandb "" --proj_dir "out_rm" \ --data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \ --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \ --micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ @@ -77,7 +77,7 @@ python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \ ### PPO Model (Reinforcement learning from Human Feedback) ``` -python train_rm.py --load_sft_model "rwkv-190.pth" --load_rm_model "rm-6.pth" --wandb "" \ +python train_rm.py --load_sft_model "./out_sft/rwkv-190.pth" --load_rm_model "./out_rm/rm-2.pth" --wandb "" \ --proj_dir "out_rlhf" \ --data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \ --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \ diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py index 859d17e..faf072c 100644 --- a/src/rlhf/ppo.py +++ b/src/rlhf/ppo.py @@ -9,6 +9,9 @@ from random import randrange from beartype import beartype from beartype.typing import List, Optional, Callable, Deque +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + import torch from torch import nn import torch.nn.functional as F @@ -18,9 +21,8 @@ from torch.utils.data import Dataset, DataLoader from torch.nn.utils.rnn import pad_sequence from pytorch_lightning.utilities import rank_zero_info - -from einops import rearrange, repeat -from einops.layers.torch import Rearrange +from pytorch_lightning.strategies import DeepSpeedStrategy +from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam from src.model import RWKV from src.rlhf.reward import RewardModel @@ -126,18 +128,18 @@ class ActorCritic(nn.Module): mask = None, return_values = True ): - action_logits = self.actor( + action_logits, _ = self.actor( x, - finetune_scope = self.actor_lora_scope + ppo_train = True ) if not return_values: return action_logits, None - critic_embeds = self.critic( + _, critic_embeds = self.critic( x, return_only_embedding = True, - finetune_scope = self.critic_lora_scope + ppo_train = True ) if self.pooled_values: @@ -287,13 +289,7 @@ class RLHF(nn.Module): # 使用 RWKV 初始化 actor_critic actor_critic = ActorCritic( rwkv = self.rwkv, - actor_lora = args.actor_lora, - critic_lora = args.critic_lora, - actor_lora_r = args.actor_lora_r, - critic_lora_r = args.critic_lora_r, - pooled_values = args.critic_pooled_values, - actor_dropout = args.actor_dropout, - critic_dropout = args.critic_dropout + pooled_values = args.critic_pooled_values ).to(self.rwkv.device) self.actor_critic = actor_critic @@ -316,6 +312,61 @@ class RLHF(nn.Module): @property def device(self): return self.accelerate.device + + def configure_optimizers(self): + args = self.args + if args.layerwise_lr > 0: + lr_1x = set() + lr_2x = set() + lr_3x = set() + for n, p in self.named_parameters(): + if "time_mix" in n: + if args.my_pile_stage == 2: + lr_2x.add(n) + else: + lr_1x.add(n) + elif "time_decay" in n: + if args.my_pile_stage == 2: + lr_3x.add(n) + else: + lr_2x.add(n) + elif "time_first" in n: + lr_3x.add(n) + else: + lr_1x.add(n) + lr_1x = sorted(list(lr_1x)) + lr_2x = sorted(list(lr_2x)) + lr_3x = sorted(list(lr_3x)) + param_dict = {n: p for n, p in self.named_parameters()} + if args.my_pile_stage == 2: + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, + ] + else: + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, + ] + else: + optim_groups = [ + {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, + ] + + if self.deepspeed_offload: + return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) + return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) + + @property + def deepspeed_offload(self) -> bool: + strategy = self.trainer.strategy + if isinstance(strategy, DeepSpeedStrategy): + cfg = strategy.config["zero_optimization"] + return cfg.get("offload_optimizer") or cfg.get("offload_param") + return False @torch.no_grad() def generate( @@ -383,7 +434,7 @@ class RLHF(nn.Module): mask = action_masks ) - action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token + action_logits = shift(action_logits, shift=1, dim=-2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token action_len = old_log_probs.shape[-1] action_probs = action_logits.softmax(dim = -1) diff --git a/src/trainer.py b/src/trainer.py index 3ff07af..d309b79 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -391,7 +391,7 @@ class rlhf_train_callback(pl.Callback): def on_train_epoch_start(self, trainer, pl_module): args = self.args dataset = trainer.train_dataloader.dataset.datasets - assert "RMDataset" in str(dataset) + assert "PPODataset" in str(dataset) dataset.global_rank = trainer.global_rank dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch) dataset.world_size = trainer.world_size diff --git a/train_ppo.py b/train_ppo.py index 5e9c892..951f964 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -289,6 +289,10 @@ if __name__ == "__main__": else: print(f"{str(shape[0]).ljust(5)} {n}") + if "deepspeed" in args.strategy: + trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + train_data = PPODataset(memory) data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) -- GitLab