Commit e7dc79af authored by u010280923

opt ppo model

Parent 65604ada
......@@ -74,7 +74,18 @@ python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
```
### Integrating RLHF (Reinforcement Learning with Human Feedback)
### PPO Model (Reinforcement Learning from Human Feedback)
```
python train_ppo.py --load_sft_model "rwkv-190.pth" --load_rm_model "rm-6.pth" --wandb "" \
--proj_dir "out_rlhf" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
--my_qa_mask 1
```
......@@ -323,4 +323,47 @@ class RMDataset(Dataset):
m_p = torch.tensor(prompt_prefer_mask, dtype=torch.long)
m_a = torch.tensor(prompt_alter_mask, dtype=torch.long)
return x_p, x_a, m_p, m_a
\ No newline at end of file
return x_p, x_a, m_p, m_a
class PPODataset(Dataset):
def __init__(self, memory):
self.data = memory
def __len__(self):
return len(self.data)
def __getitem__(self, index):
# todo(luxin): is padding needed here?
sequence, \
prompt_mask, \
mask, \
action_prob, \
action_log_prob, \
reward, \
value = self.data[index]
return sequence, prompt_mask, mask, action_prob, action_log_prob, reward, value
def load_prompt_data_4_ppo(args):
prompt_token_ids = []
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
tokenizer = TOKENIZER(WORD_NAME)
pf = pd.read_csv(args.data_file)
for index, row in pf.iterrows():
prompt = row["prompt"]
prompt_token_ids.append(tokenizer.tokenizer.encode(prompt))
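# note: stacking into a single tensor assumes every prompt tokenizes to the same length;
# variable-length prompts would need to be padded (or kept as a list) first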
prompt_token_ids = torch.tensor(prompt_token_ids, dtype=torch.long)
return prompt_token_ids
......@@ -12,6 +12,15 @@ from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from tqdm import tqdm
from einops import pack
from einops import unpack
from src.rlhf.utils import exists
from src.rlhf.utils import gumbel_sample
from src.rlhf.utils import top_k
from src.rlhf.utils import identity
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
try:
......@@ -429,7 +438,7 @@ class RWKV(pl.LightningModule):
return cfg.get("offload_optimizer") or cfg.get("offload_param")
return False
def forward(self, idx, extra_embed=None, rm_train=False):
def forward(self, idx, extra_embed=None, rm_train=False, ppo_train=False):
args = self.args
B, T = idx.size()
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
......@@ -454,15 +463,15 @@ class RWKV(pl.LightningModule):
else:
x = block(x)
x = self.ln_out(x)
embeds = self.ln_out(x)
# embeddings used by the reward model (RM)
if rm_train is True:
return x
return embeds
if args.head_qk > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
q = self.head_q(embeds)[:, :T, :]
k = self.head_k(embeds)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
......@@ -473,11 +482,66 @@ class RWKV(pl.LightningModule):
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
x = self.head(x) + c
logits = self.head(embeds) + c
else:
x = self.head(x)
return x
logits = self.head(embeds)
# for the PPO model
if ppo_train is True:
return logits, embeds
return logits
@torch.no_grad()
def generate(
self,
seq_len,
prompt = None,
temperature = 1.,
filter_logits_fn = top_k,
filter_thres = 0.9,
pad_value = 0.,
eos_token = None,
return_seq_without_prompt = True,
use_tqdm = False,
**kwargs
):
'''
Autoregressively sample up to seq_len tokens, continuing from `prompt`: logits are filtered
with `filter_logits_fn` (top-k by default), sampled via gumbel_sample at the given temperature,
and generation optionally stops at `eos_token`, masking everything after it with `pad_value`.
Returns only the generated continuation unless return_seq_without_prompt is False.
'''
prompt, leading_dims = pack([prompt], '* n')
n, out = prompt.shape[-1], prompt.clone()
wrapper_fn = identity if not use_tqdm else tqdm
sample_num_times = max(1, seq_len - prompt.shape[-1])
for _ in wrapper_fn(range(sample_num_times)):
logits, embeds = self.forward(out, ppo_train = True, **kwargs)
logits, embeds = logits[:, -1], embeds[:, -1]
if exists(filter_logits_fn):
logits = filter_logits_fn(logits, thres = filter_thres)
sample = gumbel_sample(logits, temperature = temperature, dim = -1)
out, _ = pack([out, sample], 'b *')
if exists(eos_token):
is_eos_tokens = (out == eos_token)
if is_eos_tokens.any(dim = -1).all():
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
out = out.masked_fill(mask, pad_value)
break
out, = unpack(out, leading_dims, '* n')
if not return_seq_without_prompt:
return out
return out[..., n:]
def training_step(self, batch, batch_idx):
args = self.args
......
This diff is collapsed.
import math
from pathlib import Path
import copy
from tqdm import tqdm
from functools import partial
from collections import deque, namedtuple
from random import randrange
from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from pytorch_lightning.utilities import rank_zero_info
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.optimizer import get_optimizer
from src.rlhf.utils import masked_mean, eval_decorator
from accelerate import Accelerator
# actor critic - RWKV with LoRA
PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
'actions',
'sequence',
'mask',
'prompt_mask',
'action_logits',
'values'
])
@beartype
class ActorCritic(nn.Module):
def __init__(
self,
rwkv: RWKV,
critic_palm: Optional[RWKV] = None,
pooled_values = False,
actor_lora = True,
critic_lora = True,
actor_lora_r = 8,
critic_lora_r = 8,
actor_lora_scope = 'actor',
critic_lora_scope = 'critic',
actor_dropout = 0.,
critic_dropout = 0.
):
super().__init__()
self.actor_palm = rwkv
self.critic_palm = critic_palm
if not exists(self.critic_palm):
self.critic_palm = copy.deepcopy(rwkv)
self.actor_palm.set_dropout(actor_dropout)
self.critic_palm.set_dropout(critic_dropout)
self.actor_lora = actor_lora
self.critic_lora = critic_lora
self.actor_lora_scope = actor_lora_scope if actor_lora else None
self.critic_lora_scope = critic_lora_scope if critic_lora else None
if self.actor_lora:
self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)
if self.critic_lora:
self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
self.pooled_values = pooled_values
self.value_head = nn.Sequential(
nn.Linear(rwkv.dim, 1),
Rearrange('... 1 -> ...')
)
nn.init.zeros_(self.value_head[0].bias)
nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
def actor_parameters(self):
if not self.actor_lora:
return self.actor_palm.parameters()
return [
*self.actor_palm.finetune_parameters(self.actor_lora_scope)
]
def critic_parameters(self):
if not self.critic_lora:
return [*self.critic_palm.parameters(), *self.value_head.parameters()]
return [
*self.critic_palm.finetune_parameters(self.critic_lora_scope),
*self.value_head.parameters()
]
@torch.no_grad()
@eval_decorator
def generate(
self,
state,
max_seq_len,
eos_token = None,
return_values = False,
**kwargs
):
actions = self.actor_palm.generate(
max_seq_len,
prompt = state,
eos_token = eos_token,
finetune_scope = self.actor_lora_scope,
use_tqdm = True,
**kwargs
)
sequence = torch.cat((state, actions), dim = -1)
action_len = actions.shape[-1]
state_len = state.shape[-1]
prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])
action_mask = ~prompt_mask
mask = None
if exists(eos_token):
mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
mask = F.pad(mask, (1, -1), value = True) # include eos token
action_mask &= mask
action_logits, value = self.forward(
sequence,
mask = action_mask,
return_values = return_values
)
return PPOActionCriticReturn(
actions,
sequence,
mask,
prompt_mask,
action_logits,
value
)
def forward(
self,
x,
mask = None,
return_values = True
):
action_logits = self.actor_palm(
x,
finetune_scope = self.actor_lora_scope
)
if not return_values:
return action_logits, None
critic_embeds = self.critic_palm(
x,
return_only_embedding = True,
finetune_scope = self.critic_lora_scope
)
if self.pooled_values:
critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
critic_embeds = masked_mean(critic_embeds, mask, dim = 1)
values = self.value_head(critic_embeds)
return action_logits, values
# data
Memory = namedtuple('Memory', [
'sequence',
'prompt_mask',
'mask',
'action_prob',
'action_log_prob',
'reward',
'value'
])
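# one Memory holds a single rollout: the sampled sequence, its prompt / padding masks,
# the policy's action probabilities and log-probabilities, the reward-model score, and the critic value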
@beartype
class ExperienceDataset(Dataset):
def __init__(
self,
data: List[torch.Tensor],
device = None
):
super().__init__()
self.data = data
self.device = device
def __len__(self):
return self.data[0].shape[0]
def __getitem__(self, ind):
return tuple(map(lambda t: t[ind].to(self.device), self.data))
def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
ds = ExperienceDataset(data, device = device)
return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
dim = default(dim, tuple(range(t.ndim)))
kwargs = dict(dim = dim, keepdim = True)
mean = masked_mean(t, mask = mask, **kwargs)
mean_centered = t - mean
var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)
return mean_centered * var.clamp(min = eps).rsqrt()
def pad_sequence_fixed(sequences, *args, **kwargs):
first_el = sequences[0]
has_no_dimension = first_el.ndim == 0
# if no dimensions, add a single dimension
if has_no_dimension:
sequences = tuple(map(lambda t: t[None], sequences))
out = pad_sequence(sequences, *args, **kwargs)
if has_no_dimension:
out = rearrange(out, '... 1 -> ...')
return out
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
def log_prob(prob, indices):
assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
return log(prob.gather(-1, indices[..., None])).squeeze(-1)
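# shift a tensor by `shift` positions along `dim`, filling the vacated slots with `value`;
# used here to align each position's logits with the token generated at that position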
def shift(t, value = 0, shift = 1, dim = -1):
zeros = (0, 0) * (-dim - 1)
return F.pad(t, (*zeros, shift, -shift), value = value)
def masked_entropy(prob, dim = -1, mask = None):
entropies = (prob * log(prob)).sum(dim = -1)
return masked_mean(entropies, mask = mask).mean()
def masked_kl_div(prob1, prob2, mask = None):
"""
need to account for variable sequence lengths, therefore not using the built-in functional version
"""
kl_divs = (prob1 * (log(prob2) - log(prob1))).sum(dim = -1)
if not exists(mask):
return kl_divs.mean()
return masked_mean(kl_divs, mask).mean()
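# clipped value loss (PPO2-style): compute the squared error against rewards for both the raw
# values and values clipped to within `clip` of old_values, then take the element-wise maximum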
def clipped_value_loss(values, rewards, old_values, clip):
value_clipped = old_values + (values - old_values).clamp(-clip, clip)
value_loss_1 = (value_clipped.flatten() - rewards) ** 2
value_loss_2 = (values.flatten() - rewards) ** 2
return torch.mean(torch.max(value_loss_1, value_loss_2))
# rlhf trainer
@beartype
class RLHFTrainer(nn.Module):
def __init__(
self,
args,
accelerate_kwargs: dict = {}
):
super().__init__()
self.args = args
self.accelerate = Accelerator(**accelerate_kwargs)
# load the RWKV model
rwkv = RWKV(args)
if len(args.load_sft_model) == 0:
rank_zero_info(f"SFT must load model, please input ")
exit(1)
rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
try:
load_dict = torch.load(args.load_sft_model, map_location="cpu")
except:
rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
exit(1)
if args.load_partial == 1:
load_keys = load_dict.keys()
for k in rwkv.state_dict():
if k not in load_keys:
load_dict[k] = rwkv.state_dict()[k]
rwkv.load_state_dict(load_dict)
self.rwkv = rwkv
# initialize the actor-critic with the RWKV model
actor_critic = ActorCritic(
rwkv = self.rwkv,
actor_lora = args.actor_lora,
critic_lora = args.critic_lora,
actor_lora_r = args.actor_lora_r,
critic_lora_r = args.critic_lora_r,
pooled_values = args.critic_pooled_values,
actor_dropout = args.actor_dropout,
critic_dropout = args.critic_dropout
).to(self.rwkv.device)
self.actor_critic = actor_critic
# load the reward model and set it to evaluation mode
reward_model = RewardModel(args)
reward_model.load(args.load_rm_model)
self.reward_model = reward_model.eval()
# optimizers
self.actor_optim = get_optimizer(actor_critic.actor_parameters(), lr = self.args.actor_lr, wd = self.args.actor_wd, betas = self.args.betas, eps = self.args.actor_adam_eps, use_lion = self.args.use_lion)
self.critic_optim = get_optimizer(actor_critic.critic_parameters(), lr = self.args.critic_lr, wd = self.args.critic_wd, betas = self.args.betas, eps = self.args.critic_adam_eps, use_lion = self.args.use_lion)
# prepare with accelerator
(
self.actor_critic,
self.reward_model,
self.actor_optim,
self.critic_optim
) = self.accelerate.prepare(
self.actor_critic,
self.reward_model,
self.actor_optim,
self.critic_optim
)
def print(self, msg):
return self.accelerate.print(msg)
def save(self, filepath = './checkpoint.pt'):
torch.save(self.actor_critic.state_dict(), filepath)
def load(self, filepath = './checkpoint.pt'):
state_dict = torch.load(filepath)
self.actor_critic.load_state_dict(state_dict)
@property
def device(self):
return self.accelerate.device
@torch.no_grad()
def generate(
self,
max_seq_len,
*args,
prompt,
num_samples = 4, # sample 4 per prompt and select the one with highest reward
**kwargs
):
assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
prompt = repeat(prompt, 'n -> b n', b = num_samples)
actor_critic = self.accelerate.unwrap_model(self.actor_critic)
reward_model = self.accelerate.unwrap_model(self.reward_model)
actor_critic.eval()
(
actions,
sequences,
mask,
prompt_mask,
action_logits,
_
) = actor_critic.generate(
prompt,
*args,
max_seq_len = max_seq_len,
return_values = False,
**kwargs
)
rewards = reward_model(
sequences,
prompt_mask = prompt_mask,
mask = mask,
sample = True
)
best_sequence_index = rewards.topk(1, dim = -1).indices
best_sequence = sequences[best_sequence_index]
best_sequence = rearrange(best_sequence, '1 ... -> ...')
return best_sequence
def learn(
self,
memories: Deque[Memory]
):
# stack all data stored in the memories
all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories)))
# prepare dataloader for policy phase training
dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device)
self.actor_critic.train()
# PPO training
for _ in range(self.epochs):
for (
sequences,
prompt_masks,
masks,
old_action_probs,
old_log_probs,
rewards,
old_values
) in dl:
action_masks = ~prompt_masks & masks
action_logits, values = self.actor_critic(
sequences,
mask = action_masks
)
action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
action_len = old_log_probs.shape[-1]
action_probs = action_logits.softmax(dim = -1)
action_log_probs = log_prob(action_probs, sequences)
action_log_probs = action_log_probs[:, -action_len:]
# calculate entropies, taking into account which part of the sequence is actually an action
entropies = masked_entropy(action_probs, mask = action_masks)
# calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
kl_div_loss = 0.
if self.args.kl_div_loss_weight > 0:
kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight
# handle non-pooled values
normalize_kwargs = dict()
if old_values.ndim == 2:
old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))
old_values = old_values[:, -action_len:]
values = values[:, -action_len:]
rewards = rearrange(rewards, 'b -> b 1')
normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])
if values.ndim < rewards.ndim:
values = rearrange(values, '... -> ... 1')
# calculate clipped surrogate objective, classic PPO loss
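# ratio r_t = exp(new_log_prob - old_log_prob); the objective takes min(r_t * A_t, clip(r_t, 1 - eps_clip, 1 + eps_clip) * A_t),
# and the negative sign below turns this maximization into a loss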
ratios = (action_log_probs - old_log_probs).exp()
advantages = masked_normalize(rewards - old_values, **normalize_kwargs)
if advantages.ndim == 1:
advantages = rearrange(advantages, 'b -> b 1')
surr1 = ratios * advantages
surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies
# combine losses
loss = policy_loss.mean() + kl_div_loss
# update actor
self.accelerate.backward(loss)
self.print(f'policy_loss: {loss.item():.3f}')
if exists(self.args.max_norm):
self.accelerate.clip_grad_norm_(self.actor_critic.actor_parameters(), self.args.max_norm)
self.actor_optim.step()
self.actor_optim.zero_grad()
# calculate value loss and update value network separate from policy network
value_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
value_loss = value_loss.mean()
self.print(f'critic_loss: {value_loss.item():.3f}')
self.accelerate.backward(value_loss)
if exists(self.args.max_norm):
self.accelerate.clip_grad_norm_(self.actor_critic.critic_parameters(), self.args.max_norm)
self.critic_optim.step()
self.critic_optim.zero_grad()
def train(
self,
num_episodes = 50000,
max_timesteps = 500,
update_timesteps = 5000,
max_batch_size = 16,
eos_token = None,
temperature = 1.
):
device = self.device
time = 0
memories = deque([])
for eps in tqdm(range(num_episodes), desc = 'episodes'):
for timestep in range(max_timesteps):
time += 1
# select a bunch of random states (prompts)
# and get the action (sampled sequence from palm as well as the action probs)
# also calculate the reward using reward model and store
rand_prompt_index = randrange(0, self.num_prompts)
state = self.prompt_token_ids[rand_prompt_index]
# remove padding from state
state_mask = state != self.args.pad_value
state = state[state_mask]
# get predicted sequence
(
actions,
sequence,
mask,
prompt_mask,
action_logits,
value
) = self.actor_critic.generate(
rearrange(state, 'n -> 1 n'),
max_seq_len = self.args.ctx_len,
eos_token = eos_token,
temperature = temperature,
return_values = True
)
action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
action_prob = action_logits.softmax(dim = -1)
action_len = actions.shape[-1]
action_log_prob = log_prob(action_prob, sequence)
action_log_prob = action_log_prob[:, -action_len:]
actions = rearrange(actions, '1 ... -> ...')
# get reward as given by supervised trained reward model
sequence = torch.cat((state, actions), dim = 0)
prompt_length = len(state)
prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length
sequence = rearrange(sequence, 'n -> 1 n')
prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)
reward = self.reward_model(
sequence,
prompt_mask = prompt_mask,
mask = mask,
sample = True
)
detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')
# store memory for learning
memories.append(Memory(*map(detach_to_cpu_, (
sequence,
prompt_mask,
mask,
action_prob,
action_log_prob,
reward,
value
))))
# learn from the stored memories
if time % update_timesteps == 0:
self.learn(memories)
memories.clear()
print('rlhf training complete')
......@@ -8,6 +8,9 @@ from einops import rearrange
def exists(val):
return val is not None
def identity(t, *args, **kwargs):
return t
# decorators
def eval_decorator(fn):
......
......@@ -287,6 +287,142 @@ class rm_train_callback(pl.Callback):
trainer.my_loss_count = 0
class rlhf_train_callback(pl.Callback):
def __init__(self, args):
super().__init__()
self.args = args
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
args = self.args
# if args.cuda_cleanup > 0:
# torch.cuda.empty_cache()
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
# LR schedule
w_step = args.warmup_steps
if args.lr_final == args.lr_init or args.epoch_count == 0:
lr = args.lr_init
else:
decay_step = real_step - args.my_pile_edecay * args.epoch_steps
decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
progress = (decay_step - w_step + 1) / (decay_total - w_step)
progress = min(1, max(0, progress))
if args.lr_final == 0 or args.lr_init == 0: # linear decay
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
else: # exp decay
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
if trainer.global_step < w_step:
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
# if trainer.is_global_zero:
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
for param_group in trainer.optimizers[0].param_groups:
if args.layerwise_lr > 0:
param_group["lr"] = lr * param_group["my_lr_scale"]
# print(param_group["lr"], param_group["my_lr_scale"])
else:
param_group["lr"] = lr
trainer.my_lr = lr
# rank_zero_info(f"{real_step} {lr}")
if trainer.global_step == 0:
if trainer.is_global_zero: # logging
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
try:
print(f"\n{trainer.strategy.config}\n")
trainer.my_log.write(f"{trainer.strategy.config}\n")
except:
pass
trainer.my_log.flush()
if len(args.wandb) > 0:
print("Login to wandb...")
import wandb
wandb.init(
project=args.wandb,
name=args.run_name + " " + args.my_timestamp,
config=args,
save_code=False,
)
trainer.my_wandb = wandb
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
args = self.args
if trainer.is_global_zero: # logging
t_now = time.time_ns()
token_per_step = args.ctx_len * args.real_bsz
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
kt_s = 0
try:
t_cost = (t_now - trainer.my_time_ns) / 1e9
kt_s = token_per_step / t_cost / 1000
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
except:
pass
trainer.my_time_ns = t_now
trainer.my_loss = trainer.my_loss_all.float().mean().item()
trainer.my_loss_sum += trainer.my_loss
trainer.my_loss_count += 1
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
# self.log("s", real_step, prog_bar=True, on_step=True)
if len(args.wandb) > 0:
lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
if kt_s > 0:
lll["kt/s"] = kt_s
trainer.my_wandb.log(lll, step=int(real_step))
if args.magic_prime > 0:
if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict()
my_save(
to_save_dict,
f"{args.proj_dir}/rlhf-final.pth",
)
def on_train_epoch_start(self, trainer, pl_module):
args = self.args
dataset = trainer.train_dataloader.dataset.datasets
assert "RMDataset" in str(dataset)
dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
def on_train_epoch_end(self, trainer, pl_module):
args = self.args
if trainer.is_global_zero: # logging & save state_dict
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
if args.data_type == 'wds_img':
raw_dict = pl_module.state_dict()
to_save_dict = {}
for k in raw_dict:
if k.startswith('encoder.') or k.startswith('decoder.'):
to_save_dict[k] = raw_dict[k]
else:
to_save_dict = pl_module.state_dict()
try:
my_save(
to_save_dict,
f"{args.proj_dir}/rlhf-{args.epoch_begin + trainer.current_epoch}.pth",
)
except Exception as e:
print('Error\n\n', e, '\n\n')
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
trainer.my_log.flush()
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
@rank_zero_only
def generate_init_weight(model, init_weight_name):
mm = model.generate_init_weight()
......
......@@ -6,38 +6,293 @@
'''
# here put the import lib
import torch
from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.ppo import RLHFTrainer
# load your pretrained RWKV
# todo(luxin): load the pretrained model after SFT
rwkv_model = RWKV()
# palm.load('./path/to/pretrained/palm.pt')
# load your pretrained reward model
# todo(luxin): load the trained reward model
reward_model = RewardModel(
rwkv_model,
num_binned_output = 5
)
# reward_model.load('./path/to/pretrained/reward_model.pt')
# ready your list of prompts for reinforcement learning
# todo(luxin): load the prompts dataset (these prompts should differ from those used in the SFT and RM stages)
prompts = torch.randint(0, 256, (50000, 512)) # 50k prompts
# pass it all to the trainer and train
# train the PPO model
trainer = RLHFTrainer(
palm = palm,
reward_model = reward_model,
prompt_token_ids = prompts
)
trainer.train(num_episodes = 100)
# then, if it succeeded...
# generate say 10 samples and use the reward model to return the best one
answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)
print(answer)
\ No newline at end of file
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
if __name__ == "__main__":
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
rank_zero_info("########## work in progress ##########")
########################################################################################################
#
# example: train a simple L12-D768 RWKV on dummy data
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "" --data_type "dummy" --vocab_size 0 \
# --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
# --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
# --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: train a simple L6-D512 RWKV from scratch on enwik8
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
# example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
# --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
parser = ArgumentParser()
parser.add_argument("--load_sft_model", default="", type=str) # full path, with .pth
parser.add_argument("--load_rm_model", default="", type=str) # full path, with .pth
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--random_seed", default="-1", type=int)
parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int)
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
parser.add_argument("--my_img_version", default=0, type=str)
parser.add_argument("--my_img_size", default=0, type=int)
parser.add_argument("--my_img_bit", default=0, type=int)
parser.add_argument("--my_img_clip", default='x', type=str)
parser.add_argument("--my_img_clip_scale", default=1, type=float)
parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str)
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int)
parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
# PPO model parameters
parser.add_argument("--critic_pooled_values", default=True, type=bool)
parser.add_argument("--max_norm", default=None, type=float)
parser.add_argument("--kl_div_loss_weight", default=0.1, type=float) # between old action probs and new action probs - not sure what the right value is
parser.add_argument("--eps_clip", default=0.2, type=float)
parser.add_argument("--value_clip", default=0.4, type=float)
parser.add_argument("--beta_s", default=0.01, type=float)
parser.add_argument("--actor_lr", default=1e-4, type=float)
parser.add_argument("--critic_lr", default=1e-4, type=float)
parser.add_argument("--actor_wd", default=0., type=float)
parser.add_argument("--critic_wd", default=0., type=float)
parser.add_argument("--actor_adam_eps", default=1e-7, type=float)
parser.add_argument("--critic_adam_eps", default=1e-7, type=float)
parser.add_argument("--pad_value", default=1, type=float) # token pad value
parser.add_argument("--use_lion", default=False, type=bool)
parser.add_argument("--num_episodes", default=50000, type=int)
parser.add_argument("--max_timesteps", default=500, type=int)
parser.add_argument("--update_timesteps", default=5000, type=int)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
########################################################################################################
import os, warnings, math, datetime, sys, time
import numpy as np
import torch
from torch.utils.data import DataLoader
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
seed_everything(args.random_seed)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
# os.environ["WDS_SHOW_SEED"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False
args.replace_sampler_ddp = False
args.logger = False
args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0
args.check_val_every_n_epoch = int(1e20)
args.log_every_n_steps = int(1e20)
args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
os.environ["RWKV_MY_TESTING"] = args.my_testing
if args.dim_att <= 0:
args.dim_att = args.n_embd
if args.dim_ffn <= 0:
args.dim_ffn = args.n_embd * 4
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
if not os.path.exists(args.proj_dir):
os.makedirs(args.proj_dir)
samples_per_epoch = args.epoch_steps * args.real_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len
rank_zero_info(
f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
)
rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT_ON"] = "0"
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
if args.precision == "fp32":
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if "32" in args.precision:
args.precision = 32
elif args.precision == "fp16":
args.precision = 16
else:
args.precision = "bf16"
########################################################################################################
from tqdm import tqdm
from collections import deque, namedtuple
from einops import rearrange
from src.dataset import PPODataset, load_prompt_data_4_ppo
from src.rlhf.ppo import RLHF
from src.trainer import rlhf_train_callback
# data for PPO training, collected by interacting with the environment
memory = []
# load the prompt dataset
prompts = load_prompt_data_4_ppo(args)
# PPO model
rlhf_model = RLHF(args)
# model training
# trainer
trainer = Trainer.from_argparse_args(
args,
callbacks=[rlhf_train_callback(args)],
)
time_cnt = 0
for eps in tqdm(range(args.num_episodes), desc = 'episodes'):
for timestep in range(args.max_timesteps):
time_cnt += 1
# generate experience data for PPO training
experience_data = rlhf_model.make_experience(prompts, eos_token=0)
memory.append(experience_data)
# learn from the stored memories
if time_cnt % args.update_timesteps == 0:
if trainer.global_rank == 0:
for n in rlhf_model.state_dict():
shape = rlhf_model.state_dict()[n].shape
shape = [i for i in shape if i != 1]
if len(shape) > 1:
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
else:
print(f"{str(shape[0]).ljust(5)} {n}")
train_data = PPODataset(memory)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
trainer.fit(rlhf_model, data_loader)
print('rlhf training complete')
......@@ -224,8 +224,8 @@ if __name__ == "__main__":
import torch
from tqdm import tqdm
from src.trainer import rm_train_callback
from src.rlhf.reward import RewardModel
from src.dataset import RMDataset
......