Commit e7dc79af authored by u010280923

opt ppo model

Parent 65604ada
@@ -74,7 +74,18 @@ python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
```
### PPO Model (Reinforcement Learning from Human Feedback)
```
python train_ppo.py --load_sft_model "rwkv-190.pth" --load_rm_model "rm-6.pth" --wandb "" \
--proj_dir "out_rlhf" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
--my_qa_mask 1
```
@@ -324,3 +324,46 @@ class RMDataset(Dataset):
m_a = torch.tensor(prompt_alter_mask, dtype=torch.long)
return x_p, x_a, m_p, m_a
class PPODataset(Dataset):
def __init__(self, memory):
self.data = memory
def __len__(self):
return len(self.data)
def __getitem__(self, index):
# todo(luxin): is padding needed here?
sequence, \
prompt_mask, \
mask, \
action_prob, \
action_log_prob, \
reward, \
value = self.data[index]
return sequence, prompt_mask, mask, action_prob, action_log_prob, reward, value
def load_prompt_data_4_ppo(args):
prompt_token_ids = []
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
tokenizer = TOKENIZER(WORD_NAME)
pf = pd.read_csv(args.data_file)
for index, row in pf.iterrows():
prompt = row["prompt"]
prompt_token_ids.append(tokenizer.tokenizer.encode(prompt))
prompt_token_ids = torch.tensor(prompt_token_ids, dtype=torch.long)
return prompt_token_ids
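Below is a minimal sketch (not part of the commit) of how `load_prompt_data_4_ppo` and `PPODataset` are meant to fit together; `args` and an already-populated `memory` list of rollout tuples are assumptions:
```
from torch.utils.data import DataLoader

# prompts for rollout generation: a (num_prompts, prompt_len) tensor of token ids
prompts = load_prompt_data_4_ppo(args)

# memory: list of (sequence, prompt_mask, mask, action_prob, action_log_prob, reward, value)
# tuples collected from rollouts (assumed to be filled elsewhere, e.g. by make_experience)
train_data = PPODataset(memory)
loader = DataLoader(train_data, batch_size=args.micro_bsz, shuffle=False, drop_last=True)

for sequence, prompt_mask, mask, action_prob, action_log_prob, reward, value in loader:
    break  # one PPO mini-batch
```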
@@ -12,6 +12,15 @@ from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from tqdm import tqdm
from einops import pack
from einops import unpack
from src.rlhf.utils import exists
from src.rlhf.utils import gumbel_sample
from src.rlhf.utils import top_k
from src.rlhf.utils import identity
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
try:
@@ -429,7 +438,7 @@ class RWKV(pl.LightningModule):
return cfg.get("offload_optimizer") or cfg.get("offload_param")
return False

def forward(self, idx, extra_embed=None, rm_train=False, ppo_train=False):
args = self.args
B, T = idx.size()
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
@@ -454,15 +463,15 @@ class RWKV(pl.LightningModule):
else:
x = block(x)

embeds = self.ln_out(x)

# embeddings used by the RM (reward) model
if rm_train is True:
return embeds

if args.head_qk > 0:
q = self.head_q(embeds)[:, :T, :]
k = self.head_k(embeds)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
@@ -473,11 +482,66 @@ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

logits = self.head(embeds) + c
else:
logits = self.head(embeds)

# used by the PPO model: return both logits and embeddings
if ppo_train is True:
return logits, embeds

return logits
@torch.no_grad()
def generate(
self,
seq_len,
prompt = None,
temperature = 1.,
filter_logits_fn = top_k,
filter_thres = 0.9,
pad_value = 0.,
eos_token = None,
return_seq_without_prompt = True,
use_tqdm = False,
**kwargs
):
''' Autoregressive sampling: generate up to seq_len tokens following the given prompt.
'''
prompt, leading_dims = pack([prompt], '* n')
n, out = prompt.shape[-1], prompt.clone()
wrapper_fn = identity if not use_tqdm else tqdm
sample_num_times = max(1, seq_len - prompt.shape[-1])
for _ in wrapper_fn(range(sample_num_times)):
logits, embeds = self.forward(out, ppo_train = True, **kwargs)
logits, embeds = logits[:, -1], embeds[:, -1]
if exists(filter_logits_fn):
logits = filter_logits_fn(logits, thres = filter_thres)
sample = gumbel_sample(logits, temperature = temperature, dim = -1)
out, _ = pack([out, sample], 'b *')
if exists(eos_token):
is_eos_tokens = (out == eos_token)
if is_eos_tokens.any(dim = -1).all():
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
out = out.masked_fill(mask, pad_value)
break
out, = unpack(out, leading_dims, '* n')
if not return_seq_without_prompt:
return out
return out[..., n:]
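A small usage sketch of the sampling loop above (assumes an already-constructed `RWKV` model and made-up prompt token ids; `eos_token = 0` is only an assumption):
```
import torch

prompt_ids = torch.tensor([[510, 3158, 8516]], dtype=torch.long)  # hypothetical token ids, shape (1, n)

model.eval()
# sample up to 64 tokens in total; by default only the continuation (prompt stripped) is returned
response_ids = model.generate(
    seq_len = 64,
    prompt = prompt_ids,
    temperature = 1.0,
    filter_thres = 0.9,  # top-k filtering keeps roughly the top 10% of logits
    eos_token = 0,       # assumption: 0 doubles as the eos / pad token
)
```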
def training_step(self, batch, batch_idx):
args = self.args
...
@@ -17,6 +17,8 @@ from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from pytorch_lightning.utilities import rank_zero_info

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
@@ -27,7 +29,7 @@ from src.rlhf.utils import masked_mean, eval_decorator
from accelerate import Accelerator

# actor critic - rwkv with lora

PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
'actions',
@@ -43,39 +45,16 @@ class ActorCritic(nn.Module):
def __init__(
self,
rwkv: RWKV,
critic: Optional[RWKV] = None,
pooled_values = False
):
super().__init__()

self.actor = rwkv
self.critic = critic

if not exists(self.critic):
self.critic = copy.deepcopy(rwkv)

self.pooled_values = pooled_values
self.value_head = nn.Sequential(
@@ -86,23 +65,6 @@ class ActorCritic(nn.Module):
nn.init.zeros_(self.value_head[0].bias)
nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
@torch.no_grad()
@eval_decorator
def generate(
@@ -113,7 +75,8 @@ class ActorCritic(nn.Module):
return_values = False,
**kwargs
):
# generate one response, i.e. take one action
actions = self.actor.generate(
max_seq_len,
prompt = state,
eos_token = eos_token,
@@ -122,21 +85,26 @@ class ActorCritic(nn.Module):
**kwargs
)

# concatenate the prompt (state) and the response (action)
sequence = torch.cat((state, actions), dim = -1)
action_len = actions.shape[-1]
state_len = state.shape[-1]

# build the prompt_mask (state_mask) and the response_mask (action_mask)
prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])

action_mask = ~prompt_mask

# account for the eos token
mask = None
if exists(eos_token):
mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
mask = F.pad(mask, (1, -1), value = True) # include eos token
action_mask &= mask

# feed the generated sequence to the actor to get action_logits,
# and to the critic to get value
action_logits, value = self.forward(
sequence,
mask = action_mask,
@@ -158,7 +126,7 @@ class ActorCritic(nn.Module):
mask = None,
return_values = True
):
action_logits = self.actor(
x,
)
@@ -166,7 +134,7 @@ class ActorCritic(nn.Module):
if not return_values:
return action_logits, None

critic_embeds = self.critic(
x,
return_only_embedding = True,
@@ -278,119 +246,63 @@ def clipped_value_loss(values, rewards, old_values, clip):
value_loss_2 = (values.flatten() - rewards) ** 2
return torch.mean(torch.max(value_loss_1, value_loss_2))
# rlhf
@beartype
class RLHF(nn.Module):
def __init__(
self,
args,
accelerate_kwargs: dict = {}
):
super().__init__()

self.args = args
self.accelerate = Accelerator(**accelerate_kwargs)

# load the RWKV model
rwkv = RWKV(args)
if len(args.load_sft_model) == 0:
rank_zero_info("An SFT checkpoint is required, please set --load_sft_model")
exit(1)
rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
try:
load_dict = torch.load(args.load_sft_model, map_location="cpu")
except:
rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
exit(1)

if args.load_partial == 1:
load_keys = load_dict.keys()
for k in rwkv.state_dict():
if k not in load_keys:
load_dict[k] = rwkv.state_dict()[k]
rwkv.load_state_dict(load_dict)

self.rwkv = rwkv

# initialize the actor_critic from the RWKV model
actor_critic = ActorCritic(
rwkv = self.rwkv,
actor_lora = args.actor_lora,
critic_lora = args.critic_lora,
actor_lora_r = args.actor_lora_r,
critic_lora_r = args.critic_lora_r,
pooled_values = args.critic_pooled_values,
actor_dropout = args.actor_dropout,
critic_dropout = args.critic_dropout
).to(self.rwkv.device)

self.actor_critic = actor_critic

# load the reward model and put it in evaluation mode
reward_model = RewardModel(args)
reward_model.load(args.load_rm_model)
self.reward_model = reward_model.eval()

def print(self, msg):
return self.accelerate.print(msg)
@@ -414,6 +326,9 @@ class RLHFTrainer(nn.Module):
num_samples = 4, # sample 4 per prompt and select the one with highest reward
**kwargs
):
''' Not used during training; inference only.
'''
assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
prompt = repeat(prompt, 'n -> b n', b = num_samples)
@@ -451,32 +366,16 @@ class RLHFTrainer(nn.Module):
return best_sequence
def training_step(self, batch, batch_idx):
sequences, \
prompt_masks, \
masks, \
old_action_probs, \
old_log_probs, \
rewards, \
old_values = batch

# PPO training
action_masks = ~prompt_masks & masks

action_logits, values = self.actor_critic(
@@ -499,8 +398,8 @@ class RLHFTrainer(nn.Module):
kl_div_loss = 0.

if self.args.kl_div_loss_weight > 0:
kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight

# handle non-pooled values
@@ -526,74 +425,41 @@ class RLHFTrainer(nn.Module):
advantages = rearrange(advantages, 'b -> b 1')

surr1 = ratios * advantages
surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies

# actor loss (also called the policy loss; this is the loss of the model we ultimately deploy)
actor_loss = policy_loss.mean() + kl_div_loss

# critic loss (also called the value loss)
# update the value network separately from the policy network
critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
critic_loss = critic_loss.mean()

return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}

def make_experience(self, prompts, eos_token=None, temperature=1):
''' Produce training data by interacting with the environment.
'''
device = self.device

# select a bunch of random states (prompts)
# and get the action (sampled sequence from rwkv as well as the action probs)
# also calculate the reward using reward model and store

# randomly pick one prompt
rand_prompt_index = randrange(0, self.num_prompts)
state = self.prompt_token_ids[rand_prompt_index]

# remove padding from state
state_mask = state != self.args.pad_value
state = state[state_mask]

# get predicted sequence
# interact with the environment; in the returned tuple,
# actions is the response,
# sequence is prompt + response
(
actions,
sequence,
@@ -603,7 +469,7 @@ class RLHFTrainer(nn.Module):
value
) = self.actor_critic.generate(
rearrange(state, 'n -> 1 n'),
max_seq_len = self.args.ctx_len,
eos_token = eos_token,
temperature = temperature,
return_values = True
@@ -619,7 +485,6 @@ class RLHFTrainer(nn.Module):
actions = rearrange(actions, '1 ... -> ...')

# get reward as given by supervised trained reward model
sequence = torch.cat((state, actions), dim = 0)
prompt_length = len(state)
@@ -636,11 +501,7 @@ class RLHFTrainer(nn.Module):
sample = True
)

return (
sequence,
prompt_mask,
mask,
@@ -648,12 +509,4 @@ class RLHFTrainer(nn.Module):
action_log_prob,
reward,
value
)
import math
from pathlib import Path
import copy
from tqdm import tqdm
from functools import partial
from collections import deque, namedtuple
from random import randrange
from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from pytorch_lightning.utilities import rank_zero_info
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.optimizer import get_optimizer
from src.rlhf.utils import masked_mean, eval_decorator
from accelerate import Accelerator
# actor critic - PaLM with lora
PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
'actions',
'sequence',
'mask',
'prompt_mask',
'action_logits',
'values'
])
@beartype
class ActorCritic(nn.Module):
def __init__(
self,
rwkv: RWKV,
critic_palm: Optional[RWKV] = None,
pooled_values = False,
actor_lora = True,
critic_lora = True,
actor_lora_r = 8,
critic_lora_r = 8,
actor_lora_scope = 'actor',
critic_lora_scope = 'critic',
actor_dropout = 0.,
critic_dropout = 0.
):
super().__init__()
self.actor_palm = rwkv
self.critic_palm = critic_palm
if not exists(self.critic_palm):
self.critic_palm = copy.deepcopy(rwkv)
self.actor_palm.set_dropout(actor_dropout)
self.critic_palm.set_dropout(critic_dropout)
self.actor_lora = actor_lora
self.critic_lora = critic_lora
self.actor_lora_scope = actor_lora_scope if actor_lora else None
self.critic_lora_scope = critic_lora_scope if critic_lora else None
if self.actor_lora:
self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)
if self.critic_lora:
self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
self.pooled_values = pooled_values
self.value_head = nn.Sequential(
nn.Linear(rwkv.dim, 1),
Rearrange('... 1 -> ...')
)
nn.init.zeros_(self.value_head[0].bias)
nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
def actor_parameters(self):
if not self.actor_lora:
return self.actor_palm.parameters()
return [
*self.actor_palm.finetune_parameters(self.actor_lora_scope)
]
def critic_parameters(self):
if not self.actor_lora:
return [*self.critic_palm.parameters(), *self.value_head.parameters()]
return [
*self.critic_palm.finetune_parameters(self.critic_lora_scope),
*self.value_head.parameters()
]
@torch.no_grad()
@eval_decorator
def generate(
self,
state,
max_seq_len,
eos_token = None,
return_values = False,
**kwargs
):
actions = self.actor_palm.generate(
max_seq_len,
prompt = state,
eos_token = eos_token,
finetune_scope = self.actor_lora_scope,
use_tqdm = True,
**kwargs
)
sequence = torch.cat((state, actions), dim = -1)
action_len = actions.shape[-1]
state_len = state.shape[-1]
prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])
action_mask = ~prompt_mask
mask = None
if exists(eos_token):
mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
mask = F.pad(mask, (1, -1), value = True) # include eos token
action_mask &= mask
action_logits, value = self.forward(
sequence,
mask = action_mask,
return_values = return_values
)
return PPOActionCriticReturn(
actions,
sequence,
mask,
prompt_mask,
action_logits,
value
)
def forward(
self,
x,
mask = None,
return_values = True
):
action_logits = self.actor_palm(
x,
finetune_scope = self.actor_lora_scope
)
if not return_values:
return action_logits, None
critic_embeds = self.critic_palm(
x,
return_only_embedding = True,
finetune_scope = self.critic_lora_scope
)
if self.pooled_values:
critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
critic_embeds = masked_mean(critic_embeds, mask, dim = 1)
values = self.value_head(critic_embeds)
return action_logits, values
# data
Memory = namedtuple('Memory', [
'sequence',
'prompt_mask',
'mask',
'action_prob',
'action_log_prob',
'reward',
'value'
])
@beartype
class ExperienceDataset(Dataset):
def __init__(
self,
data: List[torch.Tensor],
device = None
):
super().__init__()
self.data = data
self.device = device
def __len__(self):
return self.data[0].shape[0]
def __getitem__(self, ind):
return tuple(map(lambda t: t[ind].to(self.device), self.data))
def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
ds = ExperienceDataset(data, device = device)
return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
dim = default(dim, tuple(range(t.ndim)))
kwargs = dict(dim = dim, keepdim = True)
mean = masked_mean(t, mask = mask, **kwargs)
mean_centered = t - mean
var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)
return mean_centered * var.clamp(min = eps).rsqrt()
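For intuition, a toy call of `masked_normalize` as it is later used on the advantages (reward minus old value); the numbers are made up, and masked positions are excluded from the mean/variance statistics:
```
import torch

advantages = torch.tensor([[0.5, 1.0, -0.5],
                           [0.0, 2.0,  0.0]])
action_mask = torch.tensor([[True, True, False],
                            [True, True, True ]])

normed = masked_normalize(advantages, mask = action_mask, dim = -1)
# each row is mean-centered and rescaled to unit variance over its unmasked entries
```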
def pad_sequence_fixed(sequences, *args, **kwargs):
first_el = sequences[0]
has_no_dimension = first_el.ndim == 0
# if no dimensions, add a single dimension
if has_no_dimension:
sequences = tuple(map(lambda t: t[None], sequences))
out = pad_sequence(sequences, *args, **kwargs)
if has_no_dimension:
out = rearrange(out, '... 1 -> ...')
return out
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
def log_prob(prob, indices):
assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
return log(prob.gather(-1, indices[..., None])).squeeze(-1)
def shift(t, value = 0, shift = 1, dim = -1):
zeros = (0, 0) * (-dim - 1)
return F.pad(t, (*zeros, shift, -shift), value = value)
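The "shift along the sequence dimension by 1" step used when computing `action_log_probs` can be illustrated with `shift` and `log_prob` together: the logits produced at position t are predictions for the token at position t+1, so they are shifted right before gathering the probability of each realized token (toy shapes, random logits):
```
import torch

action_logits = torch.randn(1, 4, 5)     # (batch, seq, vocab), hypothetical
sequence = torch.tensor([[2, 4, 1, 3]])  # the tokens that were actually sampled

shifted_logits = shift(action_logits, shift = 1, dim = -2)  # align logits with the tokens they predict
probs = shifted_logits.softmax(dim = -1)
log_probs = log_prob(probs, sequence)    # (1, 4): per-token log-probabilities
```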
def masked_entropy(prob, dim = -1, mask = None):
entropies = (prob * log(prob)).sum(dim = -1)
return masked_mean(entropies, mask = mask).mean()
def masked_kl_div(prob1, prob2, mask = None):
"""
need to account for variable sequence lengths, therefore not using the built-in functional version
"""
kl_divs = (prob1 * (log(prob2) - log(prob1))).sum(dim = -1)
if not exists(mask):
return kl_divs.mean()
return masked_mean(kl_divs, mask).mean()
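A toy call of `masked_kl_div`, which implements the penalty weighted by `kl_div_loss_weight` between the old and new action distributions; identical distributions contribute exactly zero, and masked positions are ignored (made-up probabilities):
```
import torch

p_new = torch.tensor([[[0.7, 0.3], [0.5, 0.5]]])  # (batch, seq, vocab)
p_old = torch.tensor([[[0.6, 0.4], [0.5, 0.5]]])
mask  = torch.tensor([[True, False]])             # only the first position is an action token

penalty = masked_kl_div(p_new, p_old, mask = mask)  # non-zero only from the differing first position
```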
def clipped_value_loss(values, rewards, old_values, clip):
value_clipped = old_values + (values - old_values).clamp(-clip, clip)
value_loss_1 = (value_clipped.flatten() - rewards) ** 2
value_loss_2 = (values.flatten() - rewards) ** 2
return torch.mean(torch.max(value_loss_1, value_loss_2))
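A small worked example of `clipped_value_loss` with made-up numbers, showing how the value update is clipped around the old value estimate:
```
import torch

values     = torch.tensor([1.8, 0.2])  # new value predictions
old_values = torch.tensor([1.0, 0.5])
rewards    = torch.tensor([1.2, 0.4])

# clipped prediction: old_values + clamp(values - old_values, -0.4, 0.4) = [1.4, 0.2]
# loss_1 = (clipped - rewards)^2 = [0.04, 0.04]
# loss_2 = (values  - rewards)^2 = [0.36, 0.04]
# element-wise max, then mean -> (0.36 + 0.04) / 2 = 0.2
loss = clipped_value_loss(values, rewards, old_values, clip = 0.4)
```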
# rlhf trainer
@beartype
class RLHFTrainer(nn.Module):
def __init__(
self,
args,
accelerate_kwargs: dict = {}
):
super().__init__()
self.args = args
self.accelerate = Accelerator(**accelerate_kwargs)
# load the RWKV model
rwkv = RWKV(args)
if len(args.load_sft_model) == 0:
rank_zero_info("An SFT checkpoint is required, please set --load_sft_model")
exit(1)
rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
try:
load_dict = torch.load(args.load_sft_model, map_location="cpu")
except:
rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
exit(1)
if args.load_partial == 1:
load_keys = load_dict.keys()
for k in rwkv.state_dict():
if k not in load_keys:
load_dict[k] = rwkv.state_dict()[k]
rwkv.load_state_dict(load_dict)
self.rwkv = rwkv
# initialize the actor_critic from the RWKV model
actor_critic = ActorCritic(
rwkv = self.rwkv,
actor_lora = args.actor_lora,
critic_lora = args.critic_lora,
actor_lora_r = args.actor_lora_r,
critic_lora_r = args.critic_lora_r,
pooled_values = args.critic_pooled_values,
actor_dropout = args.actor_dropout,
critic_dropout = args.critic_dropout
).to(self.rwkv.device)
self.actor_critic = actor_critic
# load the reward model and put it in evaluation mode
reward_model = RewardModel(args)
reward_model.load(args.load_rm_model)
self.reward_model = reward_model.eval()
# optimizers
self.actor_optim = get_optimizer(actor_critic.actor_parameters(), lr = self.args.actor_lr, wd = self.args.actor_wd, betas = self.args.betas, eps = self.args.actor_adam_eps, use_lion = self.args.use_lion)
self.critic_optim = get_optimizer(actor_critic.critic_parameters(), lr = self.args.critic_lr, wd = self.args.critic_wd, betas = self.args.betas, eps = self.args.critic_adam_eps, use_lion = self.args.use_lion)
# prepare with accelerator
(
self.actor_critic,
self.reward_model,
self.actor_optim,
self.critic_optim
) = self.accelerate.prepare(
self.actor_critic,
self.reward_model,
self.actor_optim,
self.critic_optim
)
def print(self, msg):
return self.accelerate.print(msg)
def save(self, filepath = './checkpoint.pt'):
torch.save(self.actor_critic.state_dict(), filepath)
def load(self, filepath = './checkpoint.pt'):
state_dict = torch.load(filepath)
self.actor_critic.load_state_dict(state_dict)
@property
def device(self):
return self.accelerate.device
@torch.no_grad()
def generate(
self,
max_seq_len,
*args,
prompt,
num_samples = 4, # sample 4 per prompt and select the one with highest reward
**kwargs
):
assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
prompt = repeat(prompt, 'n -> b n', b = num_samples)
actor_critic = self.accelerate.unwrap_model(self.actor_critic)
reward_model = self.accelerate.unwrap_model(self.reward_model)
actor_critic.eval()
(
actions,
sequences,
mask,
prompt_mask,
action_logits,
_
) = actor_critic.generate(
prompt,
*args,
max_seq_len = max_seq_len,
return_values = False,
**kwargs
)
rewards = reward_model(
sequences,
prompt_mask = prompt_mask,
mask = mask,
sample = True
)
best_sequence_index = rewards.topk(1, dim = -1).indices
best_sequence = sequences[best_sequence_index]
best_sequence = rearrange(best_sequence, '1 ... -> ...')
return best_sequence
def learn(
self,
memories: Deque[Memory]
):
# stack all data stored in the memories
all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories)))
# prepare dataloader for policy phase training
dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device)
self.actor_critic.train()
# PPO training
for _ in range(self.epochs):
for (
sequences,
prompt_masks,
masks,
old_action_probs,
old_log_probs,
rewards,
old_values
) in dl:
action_masks = ~prompt_masks & masks
action_logits, values = self.actor_critic(
sequences,
mask = action_masks
)
action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
action_len = old_log_probs.shape[-1]
action_probs = action_logits.softmax(dim = -1)
action_log_probs = log_prob(action_probs, sequences)
action_log_probs = action_log_probs[:, -action_len:]
# calculate entropies, taking into account which part of the sequence is actually an action
entropies = masked_entropy(action_probs, mask = action_masks)
# calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
kl_div_loss = 0.
if self.args.kl_div_loss_weight > 0:
kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight
# handle non-pooled values
normalize_kwargs = dict()
if old_values.ndim == 2:
old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))
old_values = old_values[:, -action_len:]
values = values[:, -action_len:]
rewards = rearrange(rewards, 'b -> b 1')
normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])
if values.ndim < rewards.ndim:
values = rearrange(values, '... -> ... 1')
# calculate clipped surrogate objective, classic PPO loss
ratios = (action_log_probs - old_log_probs).exp()
advantages = masked_normalize(rewards - old_values, **normalize_kwargs)
if advantages.ndim == 1:
advantages = rearrange(advantages, 'b -> b 1')
surr1 = ratios * advantages
surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies
# combine losses
loss = policy_loss.mean() + kl_div_loss
# update actor
self.accelerate.backward(loss)
self.print(f'policy_loss: {loss.item():.3f}')
if exists(self.args.max_norm):
self.accelerate.clip_grad_norm_(self.actor_critic.actor_parameters(), self.args.max_norm)
self.actor_optim.step()
self.actor_optim.zero_grad()
# calculate value loss and update value network separate from policy network
value_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
value_loss = value_loss.mean()
self.print(f'critic_loss: {value_loss.item():.3f}')
self.accelerate.backward(value_loss)
if exists(self.args.max_norm):
self.accelerate.clip_grad_norm_(self.actor_critic.critic_parameters(), self.args.max_norm)
self.critic_optim.step()
self.critic_optim.zero_grad()
def train(
self,
num_episodes = 50000,
max_timesteps = 500,
update_timesteps = 5000,
max_batch_size = 16,
eos_token = None,
temperature = 1.
):
device = self.device
time = 0
memories = deque([])
for eps in tqdm(range(num_episodes), desc = 'episodes'):
for timestep in range(max_timesteps):
time += 1
# select a bunch of random states (prompts)
# and get the action (sampled sequence from palm as well as the action probs)
# also calculate the reward using reward model and store
rand_prompt_index = randrange(0, self.num_prompts)
state = self.prompt_token_ids[rand_prompt_index]
# remove padding from state
state_mask = state != self.args.pad_value
state = state[state_mask]
# get predicted sequence
(
actions,
sequence,
mask,
prompt_mask,
action_logits,
value
) = self.actor_critic.generate(
rearrange(state, 'n -> 1 n'),
max_seq_len = self.args.ctx_len,
eos_token = eos_token,
temperature = temperature,
return_values = True
)
action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
action_prob = action_logits.softmax(dim = -1)
action_len = actions.shape[-1]
action_log_prob = log_prob(action_prob, sequence)
action_log_prob = action_log_prob[:, -action_len:]
actions = rearrange(actions, '1 ... -> ...')
# get reward as given by supervised trained reward model
sequence = torch.cat((state, actions), dim = 0)
prompt_length = len(state)
prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length
sequence = rearrange(sequence, 'n -> 1 n')
prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)
reward = self.reward_model(
sequence,
prompt_mask = prompt_mask,
mask = mask,
sample = True
)
detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')
# store memory for learning
memories.append(Memory(*map(detach_to_cpu_, (
sequence,
prompt_mask,
mask,
action_prob,
action_log_prob,
reward,
value
))))
# learn from the stored memories
if time % update_timesteps == 0:
self.learn(memories)
memories.clear()
print('rlhf training complete')
@@ -8,6 +8,9 @@ from einops import rearrange
def exists(val):
return val is not None

def identity(t, *args, **kwargs):
return t

# decorators
def eval_decorator(fn):
...
@@ -287,6 +287,142 @@ class rm_train_callback(pl.Callback):
trainer.my_loss_count = 0
class rlhf_train_callback(pl.Callback):
def __init__(self, args):
super().__init__()
self.args = args
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
args = self.args
# if args.cuda_cleanup > 0:
# torch.cuda.empty_cache()
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
# LR schedule
w_step = args.warmup_steps
if args.lr_final == args.lr_init or args.epoch_count == 0:
lr = args.lr_init
else:
decay_step = real_step - args.my_pile_edecay * args.epoch_steps
decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
progress = (decay_step - w_step + 1) / (decay_total - w_step)
progress = min(1, max(0, progress))
if args.lr_final == 0 or args.lr_init == 0: # linear decay
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
else: # exp decay
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
if trainer.global_step < w_step:
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
# if trainer.is_global_zero:
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
for param_group in trainer.optimizers[0].param_groups:
if args.layerwise_lr > 0:
param_group["lr"] = lr * param_group["my_lr_scale"]
# print(param_group["lr"], param_group["my_lr_scale"])
else:
param_group["lr"] = lr
trainer.my_lr = lr
# rank_zero_info(f"{real_step} {lr}")
if trainer.global_step == 0:
if trainer.is_global_zero: # logging
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
try:
print(f"\n{trainer.strategy.config}\n")
trainer.my_log.write(f"{trainer.strategy.config}\n")
except:
pass
trainer.my_log.flush()
if len(args.wandb) > 0:
print("Login to wandb...")
import wandb
wandb.init(
project=args.wandb,
name=args.run_name + " " + args.my_timestamp,
config=args,
save_code=False,
)
trainer.my_wandb = wandb
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
args = self.args
if trainer.is_global_zero: # logging
t_now = time.time_ns()
token_per_step = args.ctx_len * args.real_bsz
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
kt_s = 0
try:
t_cost = (t_now - trainer.my_time_ns) / 1e9
kt_s = token_per_step / t_cost / 1000
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
except:
pass
trainer.my_time_ns = t_now
trainer.my_loss = trainer.my_loss_all.float().mean().item()
trainer.my_loss_sum += trainer.my_loss
trainer.my_loss_count += 1
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
# self.log("s", real_step, prog_bar=True, on_step=True)
if len(args.wandb) > 0:
lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
if kt_s > 0:
lll["kt/s"] = kt_s
trainer.my_wandb.log(lll, step=int(real_step))
if args.magic_prime > 0:
if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict()
my_save(
to_save_dict,
f"{args.proj_dir}/rlhf-final.pth",
)
def on_train_epoch_start(self, trainer, pl_module):
args = self.args
dataset = trainer.train_dataloader.dataset.datasets
assert "RMDataset" in str(dataset)
dataset.global_rank = trainer.global_rank
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
dataset.world_size = trainer.world_size
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
def on_train_epoch_end(self, trainer, pl_module):
args = self.args
if trainer.is_global_zero: # logging & save state_dict
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
if args.data_type == 'wds_img':
raw_dict = pl_module.state_dict()
to_save_dict = {}
for k in raw_dict:
if k.startswith('encoder.') or k.startswith('decoder.'):
to_save_dict[k] = raw_dict[k]
else:
to_save_dict = pl_module.state_dict()
try:
my_save(
to_save_dict,
f"{args.proj_dir}/rlhf-{args.epoch_begin + trainer.current_epoch}.pth",
)
except Exception as e:
print('Error\n\n', e, '\n\n')
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
trainer.my_log.flush()
trainer.my_loss_sum = 0
trainer.my_loss_count = 0
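To make the LR schedule in `on_train_batch_start` concrete, here is a standalone sketch of its exponential-decay branch for a hypothetical configuration (lr_init=1e-4, lr_final=1e-5, 1000 epochs of 200 steps, no warmup, my_pile_edecay=0):
```
import math

lr_init, lr_final = 1e-4, 1e-5
epoch_steps, epoch_count = 200, 1000
w_step, my_pile_edecay = 0, 0

def lr_at(real_step):
    decay_step = real_step - my_pile_edecay * epoch_steps
    decay_total = (epoch_count - my_pile_edecay) * epoch_steps
    progress = (decay_step - w_step + 1) / (decay_total - w_step)
    progress = min(1, max(0, progress))
    return lr_init * math.exp(math.log(lr_final / lr_init) * progress)

print(lr_at(0))        # ~1e-4 at the start
print(lr_at(100_000))  # ~3.2e-5 halfway through the decay
print(lr_at(200_000))  # 1e-5 at the end
```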
@rank_zero_only
def generate_init_weight(model, init_weight_name):
mm = model.generate_init_weight()
...
@@ -6,38 +6,293 @@
'''
# here put the import lib

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

from argparse import ArgumentParser
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

rank_zero_info("########## work in progress ##########")

########################################################################################################
#
# example: train a simple L12-D768 RWKV on dummy data
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "" --data_type "dummy" --vocab_size 0 \
# --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
# --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
# --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

# example: train a simple L6-D512 RWKV from scratch on enwik8
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
# example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
# --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
parser = ArgumentParser()
parser.add_argument("--load_sft_model", default="", type=str) # full path, with .pth
parser.add_argument("--load_rm_model", default="", type=str) # full path, with .pth
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--random_seed", default="-1", type=int)
parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int)
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
parser.add_argument("--my_img_version", default=0, type=str)
parser.add_argument("--my_img_size", default=0, type=int)
parser.add_argument("--my_img_bit", default=0, type=int)
parser.add_argument("--my_img_clip", default='x', type=str)
parser.add_argument("--my_img_clip_scale", default=1, type=float)
parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str)
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int)
parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
# PPO model parameters
parser.add_argument("--critic_pooled_values", default=True, type=bool)
parser.add_argument("--max_norm", default=None, type=float)
parser.add_argument("--kl_div_loss_weight", default=0.1, type=float) # between old action probs and new action probs - not sure what the right value is
parser.add_argument("--eps_clip", default=0.2, type=float)
parser.add_argument("--value_clip", default=0.4, type=float)
parser.add_argument("--beta_s", default=0.01, type=float)
parser.add_argument("--actor_lr", default=1e-4, type=float)
parser.add_argument("--critic_lr", default=1e-4, type=float)
parser.add_argument("--actor_wd", default=0., type=float)
parser.add_argument("--critic_wd", default=0., type=float)
parser.add_argument("--actor_adam_eps", default=1e-7, type=float)
parser.add_argument("--critic_adam_eps", default=1e-7, type=float)
parser.add_argument("--pad_value", default=1, type=float) # token pad value
parser.add_argument("--use_lion", default=False, type=bool)
parser.add_argument("--num_episodes", default=50000, type=int)
parser.add_argument("--max_timesteps", default=500, type=int)
parser.add_argument("--update_timesteps", default=5000, type=int)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
########################################################################################################
import os, warnings, math, datetime, sys, time
import numpy as np
import torch
from torch.utils.data import DataLoader
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
seed_everything(args.random_seed)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
# os.environ["WDS_SHOW_SEED"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False
args.replace_sampler_ddp = False
args.logger = False
args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0
args.check_val_every_n_epoch = int(1e20)
args.log_every_n_steps = int(1e20)
args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
os.environ["RWKV_MY_TESTING"] = args.my_testing
if args.dim_att <= 0:
args.dim_att = args.n_embd
if args.dim_ffn <= 0:
args.dim_ffn = args.n_embd * 4
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
if not os.path.exists(args.proj_dir):
os.makedirs(args.proj_dir)
samples_per_epoch = args.epoch_steps * args.real_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len
rank_zero_info(
f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
)
rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT_ON"] = "0"
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
if args.precision == "fp32":
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if "32" in args.precision:
args.precision = 32
elif args.precision == "fp16":
args.precision = 16
else:
args.precision = "bf16"
########################################################################################################
from tqdm import tqdm
from collections import deque, namedtuple
from einops import rearrange
from src.dataset import PPODataset, load_prompt_data_4_ppo
from src.rlhf.ppo import RLHF
from src.trainer import rlhf_train_callback
# training data for PPO; it has to be produced by interacting with the environment
memory = []

# load the prompt dataset
prompts = load_prompt_data_4_ppo(args)

# the PPO (RLHF) model
rlhf_model = RLHF(args)

# model training: PyTorch Lightning trainer
trainer = Trainer.from_argparse_args(
args,
callbacks=[rlhf_train_callback(args)],
)
time_cnt = 0
for eps in tqdm(range(args.num_episodes), desc = 'episodes'):
for timestep in range(args.max_timesteps):
time_cnt += 1
# generate training data for the PPO model
experience_data = rlhf_model.make_experience(prompts, eos_token=0)
memory.append(experience_data)
# learn from the stored memories
if time_cnt % args.update_timesteps == 0:
if trainer.global_rank == 0:
for n in rlhf_model.state_dict():
shape = rlhf_model.state_dict()[n].shape
shape = [i for i in shape if i != 1]
if len(shape) > 1:
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
else:
print(f"{str(shape[0]).ljust(5)} {n}")
train_data = PPODataset(memory)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
trainer.fit(rlhf_model, data_loader)
print('rlhf training complete')
@@ -224,8 +224,8 @@ if __name__ == "__main__":
import torch
from tqdm import tqdm

from src.trainer import rm_train_callback
from src.rlhf.reward import RewardModel
from src.dataset import RMDataset
...