提交 be4440da 编写于 作者: U u010280923

transfer ppo code to pytorch_lightning style

上级 2164e3e5
......@@ -12,6 +12,9 @@ from src.utils import TOKENIZER
from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime
from typing import Iterable, Callable
from torch.utils.data import IterableDataset
class MyDataset(Dataset):
def __init__(self, args):
......@@ -326,25 +329,14 @@ class RMDataset(Dataset):
return x_p, x_a, m_p, m_a
class PPODataset(Dataset):
def __init__(self, memory):
self.data = memory
class ExperienceDataset(IterableDataset):
def __init__(self, generate_batch: Callable):
super().__init__()
self.generate_batch = generate_batch
def __len__(self):
return len(self.data)
def __getitem__(self, index):
# todo(luxin) 是否需要 padding ???
sequence, \
prompt_mask, \
mask, \
action_prob, \
action_log_prob, \
reward, \
value = self.data[index]
return sequence, prompt_mask, mask, action_prob, action_log_prob, reward, value
def __iter__(self) -> Iterable:
iterator = self.generate_batch()
return iterator
def load_prompt_data_4_ppo(args):
......@@ -356,14 +348,12 @@ def load_prompt_data_4_ppo(args):
] # [vocab, vocab] for Pile model
tokenizer = TOKENIZER(WORD_NAME)
ctx_len = args.ctx_len
req_len = ctx_len
pf = pd.read_csv(args.data_file)
for index, row in pf.iterrows():
prompt = row["prompt"]
prompt_idx = tokenizer.tokenizer.encode(prompt)
prompt_idx = prompt_idx[: req_len]
prompt_idx = prompt_idx[: args.ctx_len]
prompt_token_ids.append(
torch.tensor(prompt_idx, dtype=torch.long))
......
......@@ -508,7 +508,6 @@ class RWKV(pl.LightningModule):
filter_logits_fn = top_k,
filter_thres = 0.9,
pad_value = 0.,
eos_token = None,
return_seq_without_prompt = True
):
''' 生成 response,用于 ppo 模型的训练
......@@ -521,7 +520,7 @@ class RWKV(pl.LightningModule):
sample_num_times = max(1, seq_len - prompt.shape[-1])
for _ in tqdm(range(sample_num_times), desc="gen responses"):
pad_idx = torch.tensor([[eos_token] * (self.args.ctx_len - out.shape[-1])])
pad_idx = torch.tensor([[self.args.eos_token] * (self.args.ctx_len - out.shape[-1])])
query_idx = torch.cat((out, pad_idx), dim=-1)
logits, embeds = self.forward(query_idx, ppo_train=True)
logits, embeds = logits[:, -1], embeds[:, -1]
......@@ -532,8 +531,8 @@ class RWKV(pl.LightningModule):
sample = gumbel_sample(logits, temperature = temperature, dim = -1)
out, _ = pack([out, sample], 'b *')
if exists(eos_token):
is_eos_tokens = (out == eos_token)
if exists(self.args.eos_token):
is_eos_tokens = (out == self.args.eos_token)
if is_eos_tokens.any(dim = -1).all():
# mask out everything after the eos tokens
......
......@@ -29,6 +29,8 @@ from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.optimizer import get_optimizer
from src.rlhf.utils import masked_mean, eval_decorator
from src.dataset import load_prompt_data_4_ppo
from src.dataset import ExperienceDataset
# actor critic
......@@ -52,12 +54,13 @@ class ActorCritic(nn.Module):
):
super().__init__()
self.args = args
self.actor = actor
self.critic = critic
self.pooled_values = pooled_values
self.value_head = nn.Sequential(
nn.Linear(args.n_embd, 1),
nn.Linear(self.args.n_embd, 1),
Rearrange('... 1 -> ...')
)
......@@ -70,14 +73,12 @@ class ActorCritic(nn.Module):
self,
state,
max_seq_len,
eos_token = None,
return_values = False
):
# 产生一条 response,相当于采取了一次 action
actions = self.actor.generate(
max_seq_len,
prompt = state,
eos_token = eos_token
prompt = state
)
# 将 prompt (state) 和 response (action) 进行拼接
......@@ -93,8 +94,8 @@ class ActorCritic(nn.Module):
# 考虑 eos token
mask = None
if exists(eos_token):
mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
if exists(self.args.eos_token):
mask = ((sequence == self.args.eos_token).cumsum(dim = -1) == 0)
mask = F.pad(mask, (1, -1), value = True) # include eos token
action_mask &= mask
......@@ -143,27 +144,6 @@ class ActorCritic(nn.Module):
return action_logits, values
@beartype
class ExperienceDataset(Dataset):
def __init__(
self,
data: List[torch.Tensor],
device = None
):
super().__init__()
self.data = data
self.device = device
def __len__(self):
return self.data[0].shape[0]
def __getitem__(self, ind):
return tuple(map(lambda t: t[ind].to(self.device), self.data))
def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
ds = ExperienceDataset(data, device = device)
return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)
# helper functions
def exists(val):
......@@ -230,7 +210,6 @@ def clipped_value_loss(values, rewards, old_values, clip):
return torch.mean(torch.max(value_loss_1, value_loss_2))
# rlhf
@beartype
class RLHF(pl.LightningModule):
def __init__(
......@@ -244,6 +223,18 @@ class RLHF(pl.LightningModule):
self.args = args
# 读入 prompts 数据
self.prompts = load_prompt_data_4_ppo(args)
# 用于保存与 environment 的交互数据,用于训练 actor_critic (agent)
self.sequence_batch = []
self.prompt_mask_batch = []
self.mask_batch = []
self.action_prob_batch = []
self.action_log_prob_batch = []
self.reward_batch = []
self.value_batch = []
# 使用 RWKV 初始化 actor_critic
actor_critic = ActorCritic(
args=self.args,
......@@ -266,50 +257,105 @@ class RLHF(pl.LightningModule):
def configure_optimizers(self):
args = self.args
optim_groups_actor = []
optim_groups_critic = []
if args.layerwise_lr > 0:
lr_1x = set()
lr_2x = set()
lr_3x = set()
lr_1x_actor = set()
lr_2x_actor = set()
lr_3x_actor = set()
lr_1x_critic = set()
lr_2x_critic = set()
lr_3x_critic = set()
for n, p in self.named_parameters():
if "time_mix" in n:
if args.my_pile_stage == 2:
lr_2x.add(n)
if "actor" in n:
lr_2x_actor.add(n)
elif "critic" in n:
lr_2x_critic.add(n)
else:
lr_1x.add(n)
if "actor" in n:
lr_1x_actor.add(n)
elif "critic" in n:
lr_1x_critic.add(n)
elif "time_decay" in n:
if args.my_pile_stage == 2:
lr_3x.add(n)
if "actor" in n:
lr_3x_actor.add(n)
elif "critic" in n:
lr_3x_critic.add(n)
else:
lr_2x.add(n)
if "actor" in n:
lr_2x_actor.add(n)
elif "critic" in n:
lr_2x_critic.add(n)
elif "time_first" in n:
lr_3x.add(n)
if "actor" in n:
lr_3x_actor.add(n)
elif "critic" in n:
lr_3x_critic.add(n)
else:
lr_1x.add(n)
lr_1x = sorted(list(lr_1x))
lr_2x = sorted(list(lr_2x))
lr_3x = sorted(list(lr_3x))
if "actor" in n:
lr_1x_actor.add(n)
elif "critic" in n:
lr_1x_critic.add(n)
lr_1x_actor = sorted(list(lr_1x_actor))
lr_2x_actor = sorted(list(lr_2x_actor))
lr_3x_actor = sorted(list(lr_3x_actor))
lr_1x_critic = sorted(list(lr_1x_critic))
lr_2x_critic = sorted(list(lr_2x_critic))
lr_3x_critic = sorted(list(lr_3x_critic))
param_dict = {n: p for n, p in self.named_parameters()}
if args.my_pile_stage == 2:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
optim_groups_actor = [
{"params": [param_dict[n] for n in lr_1x_actor], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x_actor], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_3x_actor], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
]
optim_groups_critic = [
{"params": [param_dict[n] for n in lr_1x_critic], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x_critic], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_3x_critic], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
]
else:
optim_groups = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
optim_groups_actor = [
{"params": [param_dict[n] for n in lr_1x_actor], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x_actor], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x_actor], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
optim_groups_critic = [
{"params": [param_dict[n] for n in lr_1x_critic], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x_critic], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x_critic], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
else:
optim_groups = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
optim_groups_actor = [
{"params": [p for n, p in self.named_parameters() if "actor" in n], "weight_decay": 0.0},
]
optim_groups_critic = [
{"params": [p for n, p in self.named_parameters() if "critic" in n], "weight_decay": 0.0},
]
if self.deepspeed_offload:
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
actor_optimizer = DeepSpeedCPUAdam(optim_groups_actor, lr=self.args.actor_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
critic_optimizer = DeepSpeedCPUAdam(optim_groups_critic, lr=self.args.critic_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
return actor_optimizer, critic_optimizer
actor_optimizer = FusedAdam(optim_groups_actor, lr=self.args.actor_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
critic_optimizer = FusedAdam(optim_groups_critic, lr=self.args.critic_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
return actor_optimizer, critic_optimizer
@property
def deepspeed_offload(self) -> bool:
......@@ -360,7 +406,7 @@ class RLHF(pl.LightningModule):
return best_sequence
def training_step(self, batch, batch_idx):
def training_step(self, batch, batch_idx, optimizer_idx):
sequences, \
prompt_masks, \
masks, \
......@@ -423,27 +469,34 @@ class RLHF(pl.LightningModule):
policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies
# actor loss (也称为 policy loss, 是最终要使用模型的 loss)
if optimizer_idx == 0:
actor_loss = policy_loss.mean() + kl_div_loss
return actor_loss
# critic loss (也称为 value loss)
# update value network separate from policy network
if optimizer_idx == 1:
critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
critic_loss = critic_loss.mean()
return critic_loss
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
def make_experience(self, prompts, eos_token=None, temperature=1):
def gen_experience_dataset(self):
''' 通过与 environment 交互产生训练数据
'''
device = self.device
time_cnt = 0
for eps in tqdm(range(self.args.num_episodes), desc = 'episodes'):
for timestep in range(self.args.max_timesteps):
time_cnt += 1
# select a bunch of random states (prompts)
# and get the action (sampled sequence from rwkv as well as the action probs)
# also calculate the reward using reward model and store
# 随机挑选一条 prompt
rand_prompt_index = randrange(0, len(prompts))
state = prompts[rand_prompt_index]
rand_prompt_index = randrange(0, len(self.prompts))
state = self.prompts[rand_prompt_index]
# remove padding from state
state_mask = state != self.args.pad_value
......@@ -463,7 +516,6 @@ class RLHF(pl.LightningModule):
) = self.actor_critic.generate(
rearrange(state, 'n -> 1 n'),
max_seq_len = self.args.ctx_len,
eos_token = eos_token,
return_values = True
)
action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
......@@ -493,12 +545,43 @@ class RLHF(pl.LightningModule):
sample = True
)
return (
sequence,
prompt_mask,
mask,
action_prob,
action_log_prob,
reward,
value
self.sequence_batch.append(sequence)
self.prompt_mask_batch.append(prompt_mask)
self.mask_batch.append(mask)
self.action_prob_batch.append(action_prob)
self.action_log_prob_batch.append(action_log_prob)
self.reward_batch.append(reward)
self.value_batch.append(value)
if time_cnt % self.args.update_timesteps == 0:
train_data = zip(
self.sequence_batch, self.prompt_mask_batch, self.mask_batch,
self.action_prob_batch, self.action_log_prob_batch, self.reward_batch,
self.value_batch
)
for _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value in train_data:
yield _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value
self.sequence_batch.clear()
self.prompt_mask_batch.clear()
self.mask_batch.clear()
self.action_prob_batch.clear()
self.action_log_prob_batch.clear()
self.reward_batch.clear()
self.value_batch.clear()
def _dataloader(self) -> DataLoader:
''' Initialize the Replay Buffer dataset used for retrieving experiences '''
dataset = ExperienceDataset(self.gen_experience_dataset)
dataloader = DataLoader(dataset=dataset, batch_size=self.args.micro_bsz)
return dataloader
def train_dataloader(self) -> DataLoader:
''' Get train loader '''
return self._dataloader()
......@@ -138,6 +138,7 @@ if __name__ == "__main__":
parser.add_argument("--num_episodes", default=50000, type=int)
parser.add_argument("--max_timesteps", default=500, type=int)
parser.add_argument("--update_timesteps", default=5000, type=int)
parser.add_argument("--eos_token", default=0, type=int)
parser = Trainer.add_argparse_args(parser)
......@@ -249,7 +250,6 @@ if __name__ == "__main__":
from collections import deque, namedtuple
from einops import rearrange
from src.dataset import PPODataset, load_prompt_data_4_ppo
from src.rlhf.ppo import RLHF
from src.trainer import rlhf_train_callback
from src.model import RWKV
......@@ -258,9 +258,6 @@ if __name__ == "__main__":
# 用于 PPO 训练的数据,需要与 environment 交互获得
memory = []
# 读入训练数据集
prompts = load_prompt_data_4_ppo(args)
# 用 rwkv 初始化 actor 模型
actor = RWKV(args)
actor.load(args.load_sft_model)
......@@ -298,21 +295,7 @@ if __name__ == "__main__":
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
time_cnt = 0
for eps in tqdm(range(args.num_episodes), desc = 'episodes'):
for timestep in range(args.max_timesteps):
time_cnt += 1
# 生成 ppo 模型的训练数据
experience_data = rlhf_model.make_experience(prompts, eos_token=0)
memory.append(experience_data)
# learn from the stored memories
if time_cnt % args.update_timesteps == 0:
train_data = PPODataset(memory)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
trainer.fit(rlhf_model, data_loader)
trainer.fit(rlhf_model)
print('rlhf training complete')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册