提交 fa6038f4 编写于 作者: U u010280923

update ppo model

上级 e7dc79af
...@@ -64,7 +64,7 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" ...@@ -64,7 +64,7 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft"
### Reward Model ### Reward Model
``` ```
python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \ python train_rm.py --load_model "./out_sft/rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \ --data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \ --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ --micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
...@@ -77,7 +77,7 @@ python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \ ...@@ -77,7 +77,7 @@ python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
### PPO Model (Reinforcement learning from Human Feedback) ### PPO Model (Reinforcement learning from Human Feedback)
``` ```
python train_rm.py --load_sft_model "rwkv-190.pth" --load_rm_model "rm-6.pth" --wandb "" \ python train_rm.py --load_sft_model "./out_sft/rwkv-190.pth" --load_rm_model "./out_rm/rm-2.pth" --wandb "" \
--proj_dir "out_rlhf" \ --proj_dir "out_rlhf" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \ --data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \ --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
......
...@@ -9,6 +9,9 @@ from random import randrange ...@@ -9,6 +9,9 @@ from random import randrange
from beartype import beartype from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque from beartype.typing import List, Optional, Callable, Deque
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -18,9 +21,8 @@ from torch.utils.data import Dataset, DataLoader ...@@ -18,9 +21,8 @@ from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.strategies import DeepSpeedStrategy
from einops import rearrange, repeat from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from einops.layers.torch import Rearrange
from src.model import RWKV from src.model import RWKV
from src.rlhf.reward import RewardModel from src.rlhf.reward import RewardModel
...@@ -126,18 +128,18 @@ class ActorCritic(nn.Module): ...@@ -126,18 +128,18 @@ class ActorCritic(nn.Module):
mask = None, mask = None,
return_values = True return_values = True
): ):
action_logits = self.actor( action_logits, _ = self.actor(
x, x,
finetune_scope = self.actor_lora_scope ppo_train = True
) )
if not return_values: if not return_values:
return action_logits, None return action_logits, None
critic_embeds = self.critic( _, critic_embeds = self.critic(
x, x,
return_only_embedding = True, return_only_embedding = True,
finetune_scope = self.critic_lora_scope ppo_train = True
) )
if self.pooled_values: if self.pooled_values:
...@@ -287,13 +289,7 @@ class RLHF(nn.Module): ...@@ -287,13 +289,7 @@ class RLHF(nn.Module):
# 使用 RWKV 初始化 actor_critic # 使用 RWKV 初始化 actor_critic
actor_critic = ActorCritic( actor_critic = ActorCritic(
rwkv = self.rwkv, rwkv = self.rwkv,
actor_lora = args.actor_lora, pooled_values = args.critic_pooled_values
critic_lora = args.critic_lora,
actor_lora_r = args.actor_lora_r,
critic_lora_r = args.critic_lora_r,
pooled_values = args.critic_pooled_values,
actor_dropout = args.actor_dropout,
critic_dropout = args.critic_dropout
).to(self.rwkv.device) ).to(self.rwkv.device)
self.actor_critic = actor_critic self.actor_critic = actor_critic
...@@ -316,6 +312,61 @@ class RLHF(nn.Module): ...@@ -316,6 +312,61 @@ class RLHF(nn.Module):
@property @property
def device(self): def device(self):
return self.accelerate.device return self.accelerate.device
def configure_optimizers(self):
args = self.args
if args.layerwise_lr > 0:
lr_1x = set()
lr_2x = set()
lr_3x = set()
for n, p in self.named_parameters():
if "time_mix" in n:
if args.my_pile_stage == 2:
lr_2x.add(n)
else:
lr_1x.add(n)
elif "time_decay" in n:
if args.my_pile_stage == 2:
lr_3x.add(n)
else:
lr_2x.add(n)
elif "time_first" in n:
lr_3x.add(n)
else:
lr_1x.add(n)
lr_1x = sorted(list(lr_1x))
lr_2x = sorted(list(lr_2x))
lr_3x = sorted(list(lr_3x))
param_dict = {n: p for n, p in self.named_parameters()}
if args.my_pile_stage == 2:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
]
else:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
else:
optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
]
if self.deepspeed_offload:
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
@property
def deepspeed_offload(self) -> bool:
strategy = self.trainer.strategy
if isinstance(strategy, DeepSpeedStrategy):
cfg = strategy.config["zero_optimization"]
return cfg.get("offload_optimizer") or cfg.get("offload_param")
return False
@torch.no_grad() @torch.no_grad()
def generate( def generate(
...@@ -383,7 +434,7 @@ class RLHF(nn.Module): ...@@ -383,7 +434,7 @@ class RLHF(nn.Module):
mask = action_masks mask = action_masks
) )
action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token action_logits = shift(action_logits, shift=1, dim=-2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
action_len = old_log_probs.shape[-1] action_len = old_log_probs.shape[-1]
action_probs = action_logits.softmax(dim = -1) action_probs = action_logits.softmax(dim = -1)
......
...@@ -391,7 +391,7 @@ class rlhf_train_callback(pl.Callback): ...@@ -391,7 +391,7 @@ class rlhf_train_callback(pl.Callback):
def on_train_epoch_start(self, trainer, pl_module): def on_train_epoch_start(self, trainer, pl_module):
args = self.args args = self.args
dataset = trainer.train_dataloader.dataset.datasets dataset = trainer.train_dataloader.dataset.datasets
assert "RMDataset" in str(dataset) assert "PPODataset" in str(dataset)
dataset.global_rank = trainer.global_rank dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch) dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size dataset.world_size = trainer.world_size
......
...@@ -289,6 +289,10 @@ if __name__ == "__main__": ...@@ -289,6 +289,10 @@ if __name__ == "__main__":
else: else:
print(f"{str(shape[0]).ljust(5)} {n}") print(f"{str(shape[0]).ljust(5)} {n}")
if "deepspeed" in args.strategy:
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
train_data = PPODataset(memory) train_data = PPODataset(memory)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册