提交 be4440da 编写于 作者: U u010280923

transfer ppo code to pytorch_lightning style

上级 2164e3e5
...@@ -12,6 +12,9 @@ from src.utils import TOKENIZER ...@@ -12,6 +12,9 @@ from src.utils import TOKENIZER
from .binidx import MMapIndexedDataset from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime from .utils import MaybeIsPrime
from typing import Iterable, Callable
from torch.utils.data import IterableDataset
class MyDataset(Dataset): class MyDataset(Dataset):
def __init__(self, args): def __init__(self, args):
...@@ -326,25 +329,14 @@ class RMDataset(Dataset): ...@@ -326,25 +329,14 @@ class RMDataset(Dataset):
return x_p, x_a, m_p, m_a return x_p, x_a, m_p, m_a
class PPODataset(Dataset): class ExperienceDataset(IterableDataset):
def __init__(self, memory): def __init__(self, generate_batch: Callable):
self.data = memory super().__init__()
self.generate_batch = generate_batch
def __len__(self): def __iter__(self) -> Iterable:
return len(self.data) iterator = self.generate_batch()
return iterator
def __getitem__(self, index):
# todo(luxin) 是否需要 padding ???
sequence, \
prompt_mask, \
mask, \
action_prob, \
action_log_prob, \
reward, \
value = self.data[index]
return sequence, prompt_mask, mask, action_prob, action_log_prob, reward, value
def load_prompt_data_4_ppo(args): def load_prompt_data_4_ppo(args):
...@@ -356,14 +348,12 @@ def load_prompt_data_4_ppo(args): ...@@ -356,14 +348,12 @@ def load_prompt_data_4_ppo(args):
] # [vocab, vocab] for Pile model ] # [vocab, vocab] for Pile model
tokenizer = TOKENIZER(WORD_NAME) tokenizer = TOKENIZER(WORD_NAME)
ctx_len = args.ctx_len
req_len = ctx_len
pf = pd.read_csv(args.data_file) pf = pd.read_csv(args.data_file)
for index, row in pf.iterrows(): for index, row in pf.iterrows():
prompt = row["prompt"] prompt = row["prompt"]
prompt_idx = tokenizer.tokenizer.encode(prompt) prompt_idx = tokenizer.tokenizer.encode(prompt)
prompt_idx = prompt_idx[: req_len] prompt_idx = prompt_idx[: args.ctx_len]
prompt_token_ids.append( prompt_token_ids.append(
torch.tensor(prompt_idx, dtype=torch.long)) torch.tensor(prompt_idx, dtype=torch.long))
......
...@@ -508,7 +508,6 @@ class RWKV(pl.LightningModule): ...@@ -508,7 +508,6 @@ class RWKV(pl.LightningModule):
filter_logits_fn = top_k, filter_logits_fn = top_k,
filter_thres = 0.9, filter_thres = 0.9,
pad_value = 0., pad_value = 0.,
eos_token = None,
return_seq_without_prompt = True return_seq_without_prompt = True
): ):
''' 生成 response,用于 ppo 模型的训练 ''' 生成 response,用于 ppo 模型的训练
...@@ -521,7 +520,7 @@ class RWKV(pl.LightningModule): ...@@ -521,7 +520,7 @@ class RWKV(pl.LightningModule):
sample_num_times = max(1, seq_len - prompt.shape[-1]) sample_num_times = max(1, seq_len - prompt.shape[-1])
for _ in tqdm(range(sample_num_times), desc="gen responses"): for _ in tqdm(range(sample_num_times), desc="gen responses"):
pad_idx = torch.tensor([[eos_token] * (self.args.ctx_len - out.shape[-1])]) pad_idx = torch.tensor([[self.args.eos_token] * (self.args.ctx_len - out.shape[-1])])
query_idx = torch.cat((out, pad_idx), dim=-1) query_idx = torch.cat((out, pad_idx), dim=-1)
logits, embeds = self.forward(query_idx, ppo_train=True) logits, embeds = self.forward(query_idx, ppo_train=True)
logits, embeds = logits[:, -1], embeds[:, -1] logits, embeds = logits[:, -1], embeds[:, -1]
...@@ -532,8 +531,8 @@ class RWKV(pl.LightningModule): ...@@ -532,8 +531,8 @@ class RWKV(pl.LightningModule):
sample = gumbel_sample(logits, temperature = temperature, dim = -1) sample = gumbel_sample(logits, temperature = temperature, dim = -1)
out, _ = pack([out, sample], 'b *') out, _ = pack([out, sample], 'b *')
if exists(eos_token): if exists(self.args.eos_token):
is_eos_tokens = (out == eos_token) is_eos_tokens = (out == self.args.eos_token)
if is_eos_tokens.any(dim = -1).all(): if is_eos_tokens.any(dim = -1).all():
# mask out everything after the eos tokens # mask out everything after the eos tokens
......
...@@ -29,6 +29,8 @@ from src.model import RWKV ...@@ -29,6 +29,8 @@ from src.model import RWKV
from src.rlhf.reward import RewardModel from src.rlhf.reward import RewardModel
from src.rlhf.optimizer import get_optimizer from src.rlhf.optimizer import get_optimizer
from src.rlhf.utils import masked_mean, eval_decorator from src.rlhf.utils import masked_mean, eval_decorator
from src.dataset import load_prompt_data_4_ppo
from src.dataset import ExperienceDataset
# actor critic # actor critic
...@@ -52,12 +54,13 @@ class ActorCritic(nn.Module): ...@@ -52,12 +54,13 @@ class ActorCritic(nn.Module):
): ):
super().__init__() super().__init__()
self.args = args
self.actor = actor self.actor = actor
self.critic = critic self.critic = critic
self.pooled_values = pooled_values self.pooled_values = pooled_values
self.value_head = nn.Sequential( self.value_head = nn.Sequential(
nn.Linear(args.n_embd, 1), nn.Linear(self.args.n_embd, 1),
Rearrange('... 1 -> ...') Rearrange('... 1 -> ...')
) )
...@@ -70,14 +73,12 @@ class ActorCritic(nn.Module): ...@@ -70,14 +73,12 @@ class ActorCritic(nn.Module):
self, self,
state, state,
max_seq_len, max_seq_len,
eos_token = None,
return_values = False return_values = False
): ):
# 产生一条 response,相当于采取了一次 action # 产生一条 response,相当于采取了一次 action
actions = self.actor.generate( actions = self.actor.generate(
max_seq_len, max_seq_len,
prompt = state, prompt = state
eos_token = eos_token
) )
# 将 prompt (state) 和 response (action) 进行拼接 # 将 prompt (state) 和 response (action) 进行拼接
...@@ -93,8 +94,8 @@ class ActorCritic(nn.Module): ...@@ -93,8 +94,8 @@ class ActorCritic(nn.Module):
# 考虑 eos token # 考虑 eos token
mask = None mask = None
if exists(eos_token): if exists(self.args.eos_token):
mask = ((sequence == eos_token).cumsum(dim = -1) == 0) mask = ((sequence == self.args.eos_token).cumsum(dim = -1) == 0)
mask = F.pad(mask, (1, -1), value = True) # include eos token mask = F.pad(mask, (1, -1), value = True) # include eos token
action_mask &= mask action_mask &= mask
...@@ -143,27 +144,6 @@ class ActorCritic(nn.Module): ...@@ -143,27 +144,6 @@ class ActorCritic(nn.Module):
return action_logits, values return action_logits, values
@beartype
class ExperienceDataset(Dataset):
def __init__(
self,
data: List[torch.Tensor],
device = None
):
super().__init__()
self.data = data
self.device = device
def __len__(self):
return self.data[0].shape[0]
def __getitem__(self, ind):
return tuple(map(lambda t: t[ind].to(self.device), self.data))
def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
ds = ExperienceDataset(data, device = device)
return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)
# helper functions # helper functions
def exists(val): def exists(val):
...@@ -230,7 +210,6 @@ def clipped_value_loss(values, rewards, old_values, clip): ...@@ -230,7 +210,6 @@ def clipped_value_loss(values, rewards, old_values, clip):
return torch.mean(torch.max(value_loss_1, value_loss_2)) return torch.mean(torch.max(value_loss_1, value_loss_2))
# rlhf # rlhf
@beartype @beartype
class RLHF(pl.LightningModule): class RLHF(pl.LightningModule):
def __init__( def __init__(
...@@ -244,6 +223,18 @@ class RLHF(pl.LightningModule): ...@@ -244,6 +223,18 @@ class RLHF(pl.LightningModule):
self.args = args self.args = args
# 读入 prompts 数据
self.prompts = load_prompt_data_4_ppo(args)
# 用于保存与 environment 的交互数据,用于训练 actor_critic (agent)
self.sequence_batch = []
self.prompt_mask_batch = []
self.mask_batch = []
self.action_prob_batch = []
self.action_log_prob_batch = []
self.reward_batch = []
self.value_batch = []
# 使用 RWKV 初始化 actor_critic # 使用 RWKV 初始化 actor_critic
actor_critic = ActorCritic( actor_critic = ActorCritic(
args=self.args, args=self.args,
...@@ -266,50 +257,105 @@ class RLHF(pl.LightningModule): ...@@ -266,50 +257,105 @@ class RLHF(pl.LightningModule):
def configure_optimizers(self): def configure_optimizers(self):
args = self.args args = self.args
optim_groups_actor = []
optim_groups_critic = []
if args.layerwise_lr > 0: if args.layerwise_lr > 0:
lr_1x = set() lr_1x_actor = set()
lr_2x = set() lr_2x_actor = set()
lr_3x = set() lr_3x_actor = set()
lr_1x_critic = set()
lr_2x_critic = set()
lr_3x_critic = set()
for n, p in self.named_parameters(): for n, p in self.named_parameters():
if "time_mix" in n: if "time_mix" in n:
if args.my_pile_stage == 2: if args.my_pile_stage == 2:
lr_2x.add(n) if "actor" in n:
lr_2x_actor.add(n)
elif "critic" in n:
lr_2x_critic.add(n)
else: else:
lr_1x.add(n) if "actor" in n:
lr_1x_actor.add(n)
elif "critic" in n:
lr_1x_critic.add(n)
elif "time_decay" in n: elif "time_decay" in n:
if args.my_pile_stage == 2: if args.my_pile_stage == 2:
lr_3x.add(n) if "actor" in n:
lr_3x_actor.add(n)
elif "critic" in n:
lr_3x_critic.add(n)
else: else:
lr_2x.add(n) if "actor" in n:
lr_2x_actor.add(n)
elif "critic" in n:
lr_2x_critic.add(n)
elif "time_first" in n: elif "time_first" in n:
lr_3x.add(n) if "actor" in n:
lr_3x_actor.add(n)
elif "critic" in n:
lr_3x_critic.add(n)
else: else:
lr_1x.add(n) if "actor" in n:
lr_1x = sorted(list(lr_1x)) lr_1x_actor.add(n)
lr_2x = sorted(list(lr_2x)) elif "critic" in n:
lr_3x = sorted(list(lr_3x)) lr_1x_critic.add(n)
lr_1x_actor = sorted(list(lr_1x_actor))
lr_2x_actor = sorted(list(lr_2x_actor))
lr_3x_actor = sorted(list(lr_3x_actor))
lr_1x_critic = sorted(list(lr_1x_critic))
lr_2x_critic = sorted(list(lr_2x_critic))
lr_3x_critic = sorted(list(lr_3x_critic))
param_dict = {n: p for n, p in self.named_parameters()} param_dict = {n: p for n, p in self.named_parameters()}
if args.my_pile_stage == 2: if args.my_pile_stage == 2:
optim_groups = [ optim_groups_actor = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, {"params": [param_dict[n] for n in lr_1x_actor], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, {"params": [param_dict[n] for n in lr_2x_actor], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, {"params": [param_dict[n] for n in lr_3x_actor], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
]
optim_groups_critic = [
{"params": [param_dict[n] for n in lr_1x_critic], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x_critic], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
{"params": [param_dict[n] for n in lr_3x_critic], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
] ]
else: else:
optim_groups = [ optim_groups_actor = [
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, {"params": [param_dict[n] for n in lr_1x_actor], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, {"params": [param_dict[n] for n in lr_2x_actor], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, {"params": [param_dict[n] for n in lr_3x_actor], "weight_decay": 0.0, "my_lr_scale": 3.0},
]
optim_groups_critic = [
{"params": [param_dict[n] for n in lr_1x_critic], "weight_decay": 0.0, "my_lr_scale": 1.0},
{"params": [param_dict[n] for n in lr_2x_critic], "weight_decay": 0.0, "my_lr_scale": 2.0},
{"params": [param_dict[n] for n in lr_3x_critic], "weight_decay": 0.0, "my_lr_scale": 3.0},
] ]
else: else:
optim_groups = [ optim_groups_actor = [
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, {"params": [p for n, p in self.named_parameters() if "actor" in n], "weight_decay": 0.0},
]
optim_groups_critic = [
{"params": [p for n, p in self.named_parameters() if "critic" in n], "weight_decay": 0.0},
] ]
if self.deepspeed_offload: if self.deepspeed_offload:
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) actor_optimizer = DeepSpeedCPUAdam(optim_groups_actor, lr=self.args.actor_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) critic_optimizer = DeepSpeedCPUAdam(optim_groups_critic, lr=self.args.critic_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
return actor_optimizer, critic_optimizer
actor_optimizer = FusedAdam(optim_groups_actor, lr=self.args.actor_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
critic_optimizer = FusedAdam(optim_groups_critic, lr=self.args.critic_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
return actor_optimizer, critic_optimizer
@property @property
def deepspeed_offload(self) -> bool: def deepspeed_offload(self) -> bool:
...@@ -360,7 +406,7 @@ class RLHF(pl.LightningModule): ...@@ -360,7 +406,7 @@ class RLHF(pl.LightningModule):
return best_sequence return best_sequence
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx, optimizer_idx):
sequences, \ sequences, \
prompt_masks, \ prompt_masks, \
masks, \ masks, \
...@@ -423,27 +469,34 @@ class RLHF(pl.LightningModule): ...@@ -423,27 +469,34 @@ class RLHF(pl.LightningModule):
policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies
# actor loss (也称为 policy loss, 是最终要使用模型的 loss) # actor loss (也称为 policy loss, 是最终要使用模型的 loss)
if optimizer_idx == 0:
actor_loss = policy_loss.mean() + kl_div_loss actor_loss = policy_loss.mean() + kl_div_loss
return actor_loss
# critic loss (也称为 value loss) # critic loss (也称为 value loss)
# update value network separate from policy network # update value network separate from policy network
if optimizer_idx == 1:
critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip) critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
critic_loss = critic_loss.mean() critic_loss = critic_loss.mean()
return critic_loss
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} def gen_experience_dataset(self):
def make_experience(self, prompts, eos_token=None, temperature=1):
''' 通过与 environment 交互产生训练数据 ''' 通过与 environment 交互产生训练数据
''' '''
device = self.device device = self.device
time_cnt = 0
for eps in tqdm(range(self.args.num_episodes), desc = 'episodes'):
for timestep in range(self.args.max_timesteps):
time_cnt += 1
# select a bunch of random states (prompts) # select a bunch of random states (prompts)
# and get the action (sampled sequence from rwkv as well as the action probs) # and get the action (sampled sequence from rwkv as well as the action probs)
# also calculate the reward using reward model and store # also calculate the reward using reward model and store
# 随机挑选一条 prompt # 随机挑选一条 prompt
rand_prompt_index = randrange(0, len(prompts)) rand_prompt_index = randrange(0, len(self.prompts))
state = prompts[rand_prompt_index] state = self.prompts[rand_prompt_index]
# remove padding from state # remove padding from state
state_mask = state != self.args.pad_value state_mask = state != self.args.pad_value
...@@ -463,7 +516,6 @@ class RLHF(pl.LightningModule): ...@@ -463,7 +516,6 @@ class RLHF(pl.LightningModule):
) = self.actor_critic.generate( ) = self.actor_critic.generate(
rearrange(state, 'n -> 1 n'), rearrange(state, 'n -> 1 n'),
max_seq_len = self.args.ctx_len, max_seq_len = self.args.ctx_len,
eos_token = eos_token,
return_values = True return_values = True
) )
action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
...@@ -493,12 +545,43 @@ class RLHF(pl.LightningModule): ...@@ -493,12 +545,43 @@ class RLHF(pl.LightningModule):
sample = True sample = True
) )
return ( self.sequence_batch.append(sequence)
sequence, self.prompt_mask_batch.append(prompt_mask)
prompt_mask, self.mask_batch.append(mask)
mask, self.action_prob_batch.append(action_prob)
action_prob, self.action_log_prob_batch.append(action_log_prob)
action_log_prob, self.reward_batch.append(reward)
reward, self.value_batch.append(value)
value
if time_cnt % self.args.update_timesteps == 0:
train_data = zip(
self.sequence_batch, self.prompt_mask_batch, self.mask_batch,
self.action_prob_batch, self.action_log_prob_batch, self.reward_batch,
self.value_batch
) )
for _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value in train_data:
yield _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value
self.sequence_batch.clear()
self.prompt_mask_batch.clear()
self.mask_batch.clear()
self.action_prob_batch.clear()
self.action_log_prob_batch.clear()
self.reward_batch.clear()
self.value_batch.clear()
def _dataloader(self) -> DataLoader:
''' Initialize the Replay Buffer dataset used for retrieving experiences '''
dataset = ExperienceDataset(self.gen_experience_dataset)
dataloader = DataLoader(dataset=dataset, batch_size=self.args.micro_bsz)
return dataloader
def train_dataloader(self) -> DataLoader:
''' Get train loader '''
return self._dataloader()
...@@ -138,6 +138,7 @@ if __name__ == "__main__": ...@@ -138,6 +138,7 @@ if __name__ == "__main__":
parser.add_argument("--num_episodes", default=50000, type=int) parser.add_argument("--num_episodes", default=50000, type=int)
parser.add_argument("--max_timesteps", default=500, type=int) parser.add_argument("--max_timesteps", default=500, type=int)
parser.add_argument("--update_timesteps", default=5000, type=int) parser.add_argument("--update_timesteps", default=5000, type=int)
parser.add_argument("--eos_token", default=0, type=int)
parser = Trainer.add_argparse_args(parser) parser = Trainer.add_argparse_args(parser)
...@@ -249,7 +250,6 @@ if __name__ == "__main__": ...@@ -249,7 +250,6 @@ if __name__ == "__main__":
from collections import deque, namedtuple from collections import deque, namedtuple
from einops import rearrange from einops import rearrange
from src.dataset import PPODataset, load_prompt_data_4_ppo
from src.rlhf.ppo import RLHF from src.rlhf.ppo import RLHF
from src.trainer import rlhf_train_callback from src.trainer import rlhf_train_callback
from src.model import RWKV from src.model import RWKV
...@@ -258,9 +258,6 @@ if __name__ == "__main__": ...@@ -258,9 +258,6 @@ if __name__ == "__main__":
# 用于 PPO 训练的数据,需要与 environment 交互获得 # 用于 PPO 训练的数据,需要与 environment 交互获得
memory = [] memory = []
# 读入训练数据集
prompts = load_prompt_data_4_ppo(args)
# 用 rwkv 初始化 actor 模型 # 用 rwkv 初始化 actor 模型
actor = RWKV(args) actor = RWKV(args)
actor.load(args.load_sft_model) actor.load(args.load_sft_model)
...@@ -298,21 +295,7 @@ if __name__ == "__main__": ...@@ -298,21 +295,7 @@ if __name__ == "__main__":
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
time_cnt = 0 trainer.fit(rlhf_model)
for eps in tqdm(range(args.num_episodes), desc = 'episodes'):
for timestep in range(args.max_timesteps):
time_cnt += 1
# 生成 ppo 模型的训练数据
experience_data = rlhf_model.make_experience(prompts, eos_token=0)
memory.append(experience_data)
# learn from the stored memories
if time_cnt % args.update_timesteps == 0:
train_data = PPODataset(memory)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
trainer.fit(rlhf_model, data_loader)
print('rlhf training complete') print('rlhf training complete')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册