Commit fa6038f4 authored by u010280923

update ppo model

Parent e7dc79af
@@ -64,7 +64,7 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft"
### Reward Model
```
python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
python train_rm.py --load_model "./out_sft/rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
@@ -77,7 +77,7 @@ python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
### PPO Model (Reinforcement learning from Human Feedback)
```
python train_rm.py --load_sft_model "rwkv-190.pth" --load_rm_model "rm-6.pth" --wandb "" \
python train_rm.py --load_sft_model "./out_sft/rwkv-190.pth" --load_rm_model "./out_rm/rm-2.pth" --wandb "" \
--proj_dir "out_rlhf" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
......
@@ -9,6 +9,9 @@ from random import randrange
from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torch
from torch import nn
import torch.nn.functional as F
@@ -18,9 +21,8 @@ from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from pytorch_lightning.utilities import rank_zero_info
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from pytorch_lightning.strategies import DeepSpeedStrategy
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from src.model import RWKV
from src.rlhf.reward import RewardModel
@@ -126,18 +128,18 @@ class ActorCritic(nn.Module):
mask = None,
return_values = True
):
action_logits = self.actor(
action_logits, _ = self.actor(
x,
finetune_scope = self.actor_lora_scope
ppo_train = True
)
if not return_values:
return action_logits, None
critic_embeds = self.critic(
_, critic_embeds = self.critic(
x,
return_only_embedding = True,
finetune_scope = self.critic_lora_scope
ppo_train = True
)
if self.pooled_values:
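For orientation, the two call sites above suggest that with `ppo_train = True` the RWKV forward returns a `(logits, embedding)` pair: the actor keeps the logits and the critic keeps the embedding. A minimal self-contained sketch of that convention follows; the shared backbone, the `value_head`, and the mean pooling are illustrative assumptions, not code from this repository.
```
import torch
from torch import nn

class ActorCriticSketch(nn.Module):
    """Illustrative only: mirrors the (logits, embedding) convention of ppo_train=True."""

    def __init__(self, rwkv: nn.Module, n_embd: int, pooled_values: bool = False):
        super().__init__()
        self.actor = rwkv                       # policy network
        self.critic = rwkv                      # value network (shared backbone for brevity)
        self.pooled_values = pooled_values
        self.value_head = nn.Linear(n_embd, 1)  # hypothetical value head, not shown in the diff

    def forward(self, x: torch.Tensor, return_values: bool = True):
        # actor call: keep the token logits, ignore the embedding
        action_logits, _ = self.actor(x, ppo_train=True)
        if not return_values:
            return action_logits, None
        # critic call: keep the embedding, ignore the logits
        _, critic_embeds = self.critic(x, ppo_train=True)
        if self.pooled_values:
            critic_embeds = critic_embeds.mean(dim=1, keepdim=True)  # pool over the sequence
        values = self.value_head(critic_embeds).squeeze(-1)          # per-token (or pooled) state values
        return action_logits, values
```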
@@ -287,13 +289,7 @@ class RLHF(nn.Module):
# initialize actor_critic from the RWKV base model
actor_critic = ActorCritic(
rwkv = self.rwkv,
actor_lora = args.actor_lora,
critic_lora = args.critic_lora,
actor_lora_r = args.actor_lora_r,
critic_lora_r = args.critic_lora_r,
pooled_values = args.critic_pooled_values,
actor_dropout = args.actor_dropout,
critic_dropout = args.critic_dropout
pooled_values = args.critic_pooled_values
).to(self.rwkv.device)
self.actor_critic = actor_critic
@@ -316,6 +312,61 @@ class RLHF(nn.Module):
@property
def device(self):
return self.accelerate.device
def configure_optimizers(self):
args = self.args
if args.layerwise_lr > 0:
lr_1x = set()
lr_2x = set()
lr_3x = set()
for n, p in self.named_parameters():
if "time_mix" in n:
if args.my_pile_stage == 2:
lr_2x.add(n)
else:
lr_1x.add(n)
elif "time_decay" in n:
if args.my_pile_stage == 2:
lr_3x.add(n)
else:
lr_2x.add(n)
elif "time_first" in n:
lr_3x.add(n)
else:
lr_1x.add(n)
lr_1x = sorted(list(lr_1x))
lr_2x = sorted(list(lr_2x))
lr_3x = sorted(list(lr_3x))
param_dict = {n: p for n, p in self.named_parameters()}
if args.my_pile_stage == 2:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
]
else:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
else:
optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
]
if self.deepspeed_offload:
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
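`my_lr_scale` is not an Adam hyperparameter; `FusedAdam` and `DeepSpeedCPUAdam` simply carry the extra key along inside `param_groups`, so it only takes effect if a training hook multiplies it into each group's learning rate. A sketch of such a hook follows (assumed here for illustration; the actual scheduling code is outside this diff):
```
def apply_layerwise_lr(optimizer, base_lr: float) -> None:
    """Scale each param group's lr by its custom `my_lr_scale` factor (default 1.0)."""
    for group in optimizer.param_groups:
        # torch optimizers keep unknown keys in the group dict, so we can read them back
        group["lr"] = base_lr * group.get("my_lr_scale", 1.0)
```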
@property
def deepspeed_offload(self) -> bool:
strategy = self.trainer.strategy
if isinstance(strategy, DeepSpeedStrategy):
cfg = strategy.config["zero_optimization"]
return cfg.get("offload_optimizer") or cfg.get("offload_param")
return False
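For context on the two `cfg.get(...)` checks: `strategy.config["zero_optimization"]` is the `zero_optimization` section of the DeepSpeed config, and offloading is detected simply by the presence of its offload sub-sections. A config fragment that would make this property return `True` might look like the sketch below (values are illustrative, not taken from this repository):
```
# Illustrative DeepSpeed zero_optimization section (example values only)
zero_optimization = {
    "stage": 2,
    "offload_optimizer": {"device": "cpu", "pin_memory": True},  # triggers DeepSpeedCPUAdam above
    # "offload_param": {"device": "cpu"},                        # ZeRO-3 parameter offload
    "allgather_bucket_size": 200 * 1000 * 1000,                  # overwritten from --ds_bucket_mb below
    "reduce_bucket_size": 200 * 1000 * 1000,
}
```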
@torch.no_grad()
def generate(
@@ -383,7 +434,7 @@ class RLHF(nn.Module):
mask = action_masks
)
action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
action_logits = shift(action_logits, shift=1, dim=-2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
action_len = old_log_probs.shape[-1]
action_probs = action_logits.softmax(dim = -1)
......
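The comment above is the key detail: the logit at position t is the distribution over the token at t+1, so shifting the logits right by one along the sequence axis lines each generated token up with the distribution it was sampled from. A sketch of such a shift helper is below (hypothetical; the repository's own `shift` may be implemented differently):
```
import torch
import torch.nn.functional as F

def shift_seq(t: torch.Tensor, shift: int = 1) -> torch.Tensor:
    """Shift a [batch, seq, vocab] tensor right by `shift` along the sequence axis,
    zero-filling the vacated positions and keeping the length unchanged."""
    if shift == 0:
        return t
    # F.pad pads the last dims first: (vocab_left, vocab_right, seq_left, seq_right)
    t = F.pad(t, (0, 0, shift, 0), value=0.0)
    return t[:, :-shift, :]
```
After the shift, slicing the last `action_len` positions and applying `softmax`, as the surrounding code does, gives the per-token probabilities of the sampled actions needed for the PPO ratio.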
@@ -391,7 +391,7 @@ class rlhf_train_callback(pl.Callback):
def on_train_epoch_start(self, trainer, pl_module):
args = self.args
dataset = trainer.train_dataloader.dataset.datasets
assert "RMDataset" in str(dataset)
assert "PPODataset" in str(dataset)
dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size
......
@@ -289,6 +289,10 @@ if __name__ == "__main__":
else:
print(f"{str(shape[0]).ljust(5)} {n}")
if "deepspeed" in args.strategy:
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
train_data = PPODataset(memory)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
......
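`PPODataset(memory)` simply exposes the collected rollouts to a `DataLoader`. A minimal sketch of such a memory-backed dataset is shown below; the per-record tuple layout is an assumption, since the structure of `memory` is not visible in this diff.
```
from torch.utils.data import Dataset

class PPODatasetSketch(Dataset):
    """Wrap a list of PPO rollout records so a DataLoader can batch them."""

    def __init__(self, memory):
        # e.g. (sequence, mask, action_log_prob, reward, value) per rollout -- assumed layout
        self.memory = list(memory)

    def __len__(self):
        return len(self.memory)

    def __getitem__(self, idx):
        return self.memory[idx]
```
Note that the callback changed above assigns `global_rank`, `real_epoch`, and `world_size` onto the dataset object, which plain attribute assignment on a class like this supports.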