From be4440dad87b13dc0457cc140420ed949a1bab99 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Wed, 22 Mar 2023 18:06:00 +0800 Subject: [PATCH] transfer ppo code to pytorch_lightning style --- src/dataset.py | 32 ++--- src/model.py | 7 +- src/rlhf/ppo.py | 327 ++++++++++++++++++++++++++++++------------------ train_ppo.py | 21 +--- 4 files changed, 221 insertions(+), 166 deletions(-) diff --git a/src/dataset.py b/src/dataset.py index 399635d..c492408 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -12,6 +12,9 @@ from src.utils import TOKENIZER from .binidx import MMapIndexedDataset from .utils import MaybeIsPrime +from typing import Iterable, Callable +from torch.utils.data import IterableDataset + class MyDataset(Dataset): def __init__(self, args): @@ -324,27 +327,16 @@ class RMDataset(Dataset): m_a = torch.tensor(prompt_alter_mask, dtype=torch.long) return x_p, x_a, m_p, m_a - -class PPODataset(Dataset): - def __init__(self, memory): - self.data = memory - - def __len__(self): - return len(self.data) - - def __getitem__(self, index): - # todo(luxin) 是否需要 padding ??? - sequence, \ - prompt_mask, \ - mask, \ - action_prob, \ - action_log_prob, \ - reward, \ - value = self.data[index] +class ExperienceDataset(IterableDataset): + def __init__(self, generate_batch: Callable): + super().__init__() + self.generate_batch = generate_batch - return sequence, prompt_mask, mask, action_prob, action_log_prob, reward, value + def __iter__(self) -> Iterable: + iterator = self.generate_batch() + return iterator def load_prompt_data_4_ppo(args): @@ -356,14 +348,12 @@ def load_prompt_data_4_ppo(args): ] # [vocab, vocab] for Pile model tokenizer = TOKENIZER(WORD_NAME) - ctx_len = args.ctx_len - req_len = ctx_len pf = pd.read_csv(args.data_file) for index, row in pf.iterrows(): prompt = row["prompt"] prompt_idx = tokenizer.tokenizer.encode(prompt) - prompt_idx = prompt_idx[: req_len] + prompt_idx = prompt_idx[: args.ctx_len] prompt_token_ids.append( torch.tensor(prompt_idx, dtype=torch.long)) diff --git a/src/model.py b/src/model.py index 9322ba1..babaf5c 100644 --- a/src/model.py +++ b/src/model.py @@ -508,7 +508,6 @@ class RWKV(pl.LightningModule): filter_logits_fn = top_k, filter_thres = 0.9, pad_value = 0., - eos_token = None, return_seq_without_prompt = True ): ''' 生成 response,用于 ppo 模型的训练 @@ -521,7 +520,7 @@ class RWKV(pl.LightningModule): sample_num_times = max(1, seq_len - prompt.shape[-1]) for _ in tqdm(range(sample_num_times), desc="gen responses"): - pad_idx = torch.tensor([[eos_token] * (self.args.ctx_len - out.shape[-1])]) + pad_idx = torch.tensor([[self.args.eos_token] * (self.args.ctx_len - out.shape[-1])]) query_idx = torch.cat((out, pad_idx), dim=-1) logits, embeds = self.forward(query_idx, ppo_train=True) logits, embeds = logits[:, -1], embeds[:, -1] @@ -532,8 +531,8 @@ class RWKV(pl.LightningModule): sample = gumbel_sample(logits, temperature = temperature, dim = -1) out, _ = pack([out, sample], 'b *') - if exists(eos_token): - is_eos_tokens = (out == eos_token) + if exists(self.args.eos_token): + is_eos_tokens = (out == self.args.eos_token) if is_eos_tokens.any(dim = -1).all(): # mask out everything after the eos tokens diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py index ef8fab1..4708954 100644 --- a/src/rlhf/ppo.py +++ b/src/rlhf/ppo.py @@ -29,6 +29,8 @@ from src.model import RWKV from src.rlhf.reward import RewardModel from src.rlhf.optimizer import get_optimizer from src.rlhf.utils import masked_mean, eval_decorator +from src.dataset import load_prompt_data_4_ppo +from src.dataset import ExperienceDataset # actor critic @@ -52,12 +54,13 @@ class ActorCritic(nn.Module): ): super().__init__() + self.args = args self.actor = actor self.critic = critic self.pooled_values = pooled_values self.value_head = nn.Sequential( - nn.Linear(args.n_embd, 1), + nn.Linear(self.args.n_embd, 1), Rearrange('... 1 -> ...') ) @@ -70,14 +73,12 @@ class ActorCritic(nn.Module): self, state, max_seq_len, - eos_token = None, return_values = False ): # 产生一条 response,相当于采取了一次 action actions = self.actor.generate( max_seq_len, - prompt = state, - eos_token = eos_token + prompt = state ) # 将 prompt (state) 和 response (action) 进行拼接 @@ -93,8 +94,8 @@ class ActorCritic(nn.Module): # 考虑 eos token mask = None - if exists(eos_token): - mask = ((sequence == eos_token).cumsum(dim = -1) == 0) + if exists(self.args.eos_token): + mask = ((sequence == self.args.eos_token).cumsum(dim = -1) == 0) mask = F.pad(mask, (1, -1), value = True) # include eos token action_mask &= mask @@ -143,27 +144,6 @@ class ActorCritic(nn.Module): return action_logits, values -@beartype -class ExperienceDataset(Dataset): - def __init__( - self, - data: List[torch.Tensor], - device = None - ): - super().__init__() - self.data = data - self.device = device - - def __len__(self): - return self.data[0].shape[0] - - def __getitem__(self, ind): - return tuple(map(lambda t: t[ind].to(self.device), self.data)) - -def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs): - ds = ExperienceDataset(data, device = device) - return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs) - # helper functions def exists(val): @@ -230,7 +210,6 @@ def clipped_value_loss(values, rewards, old_values, clip): return torch.mean(torch.max(value_loss_1, value_loss_2)) # rlhf - @beartype class RLHF(pl.LightningModule): def __init__( @@ -244,6 +223,18 @@ class RLHF(pl.LightningModule): self.args = args + # 读入 prompts 数据 + self.prompts = load_prompt_data_4_ppo(args) + + # 用于保存与 environment 的交互数据,用于训练 actor_critic (agent) + self.sequence_batch = [] + self.prompt_mask_batch = [] + self.mask_batch = [] + self.action_prob_batch = [] + self.action_log_prob_batch = [] + self.reward_batch = [] + self.value_batch = [] + # 使用 RWKV 初始化 actor_critic actor_critic = ActorCritic( args=self.args, @@ -266,50 +257,105 @@ class RLHF(pl.LightningModule): def configure_optimizers(self): args = self.args + + optim_groups_actor = [] + optim_groups_critic = [] + if args.layerwise_lr > 0: - lr_1x = set() - lr_2x = set() - lr_3x = set() + lr_1x_actor = set() + lr_2x_actor = set() + lr_3x_actor = set() + + lr_1x_critic = set() + lr_2x_critic = set() + lr_3x_critic = set() + for n, p in self.named_parameters(): if "time_mix" in n: if args.my_pile_stage == 2: - lr_2x.add(n) + if "actor" in n: + lr_2x_actor.add(n) + elif "critic" in n: + lr_2x_critic.add(n) else: - lr_1x.add(n) + if "actor" in n: + lr_1x_actor.add(n) + elif "critic" in n: + lr_1x_critic.add(n) elif "time_decay" in n: if args.my_pile_stage == 2: - lr_3x.add(n) + if "actor" in n: + lr_3x_actor.add(n) + elif "critic" in n: + lr_3x_critic.add(n) else: - lr_2x.add(n) + if "actor" in n: + lr_2x_actor.add(n) + elif "critic" in n: + lr_2x_critic.add(n) elif "time_first" in n: - lr_3x.add(n) + if "actor" in n: + lr_3x_actor.add(n) + elif "critic" in n: + lr_3x_critic.add(n) else: - lr_1x.add(n) - lr_1x = sorted(list(lr_1x)) - lr_2x = sorted(list(lr_2x)) - lr_3x = sorted(list(lr_3x)) + if "actor" in n: + lr_1x_actor.add(n) + elif "critic" in n: + lr_1x_critic.add(n) + + lr_1x_actor = sorted(list(lr_1x_actor)) + lr_2x_actor = sorted(list(lr_2x_actor)) + lr_3x_actor = sorted(list(lr_3x_actor)) + + lr_1x_critic = sorted(list(lr_1x_critic)) + lr_2x_critic = sorted(list(lr_2x_critic)) + lr_3x_critic = sorted(list(lr_3x_critic)) + param_dict = {n: p for n, p in self.named_parameters()} if args.my_pile_stage == 2: - optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, + optim_groups_actor = [ + {"params": [param_dict[n] for n in lr_1x_actor], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x_actor], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, + {"params": [param_dict[n] for n in lr_3x_actor], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, + ] + + optim_groups_critic = [ + {"params": [param_dict[n] for n in lr_1x_critic], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x_critic], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, + {"params": [param_dict[n] for n in lr_3x_critic], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, ] else: - optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, + optim_groups_actor = [ + {"params": [param_dict[n] for n in lr_1x_actor], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x_actor], "weight_decay": 0.0, "my_lr_scale": 2.0}, + {"params": [param_dict[n] for n in lr_3x_actor], "weight_decay": 0.0, "my_lr_scale": 3.0}, + ] + + optim_groups_critic = [ + {"params": [param_dict[n] for n in lr_1x_critic], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x_critic], "weight_decay": 0.0, "my_lr_scale": 2.0}, + {"params": [param_dict[n] for n in lr_3x_critic], "weight_decay": 0.0, "my_lr_scale": 3.0}, ] else: - optim_groups = [ - {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, + optim_groups_actor = [ + {"params": [p for n, p in self.named_parameters() if "actor" in n], "weight_decay": 0.0}, + ] + + optim_groups_critic = [ + {"params": [p for n, p in self.named_parameters() if "critic" in n], "weight_decay": 0.0}, ] if self.deepspeed_offload: - return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) - return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) - # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) + actor_optimizer = DeepSpeedCPUAdam(optim_groups_actor, lr=self.args.actor_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) + critic_optimizer = DeepSpeedCPUAdam(optim_groups_critic, lr=self.args.critic_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) + + return actor_optimizer, critic_optimizer + + actor_optimizer = FusedAdam(optim_groups_actor, lr=self.args.actor_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + critic_optimizer = FusedAdam(optim_groups_critic, lr=self.args.critic_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + + return actor_optimizer, critic_optimizer @property def deepspeed_offload(self) -> bool: @@ -360,7 +406,7 @@ class RLHF(pl.LightningModule): return best_sequence - def training_step(self, batch, batch_idx): + def training_step(self, batch, batch_idx, optimizer_idx): sequences, \ prompt_masks, \ masks, \ @@ -423,82 +469,119 @@ class RLHF(pl.LightningModule): policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies # actor loss (也称为 policy loss, 是最终要使用模型的 loss) - actor_loss = policy_loss.mean() + kl_div_loss + if optimizer_idx == 0: + actor_loss = policy_loss.mean() + kl_div_loss + return actor_loss # critic loss (也称为 value loss) # update value network separate from policy network - critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip) - critic_loss = critic_loss.mean() + if optimizer_idx == 1: + critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip) + critic_loss = critic_loss.mean() + return critic_loss - return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} - - def make_experience(self, prompts, eos_token=None, temperature=1): + def gen_experience_dataset(self): ''' 通过与 environment 交互产生训练数据 ''' device = self.device - # select a bunch of random states (prompts) - # and get the action (sampled sequence from rwkv as well as the action probs) - # also calculate the reward using reward model and store - # 随机挑选一条 prompt - rand_prompt_index = randrange(0, len(prompts)) - state = prompts[rand_prompt_index] - - # remove padding from state - state_mask = state != self.args.pad_value - state = state[state_mask] - - # get predicted sequence - # 与 environment 进行交互,其中返回的: - # action 是 response, - # sequence 是 prompt + response, - ( - actions, - sequence, - mask, - prompt_mask, - action_logits, - value - ) = self.actor_critic.generate( - rearrange(state, 'n -> 1 n'), - max_seq_len = self.args.ctx_len, - eos_token = eos_token, - return_values = True - ) - action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token - - action_prob = action_logits.softmax(dim = -1) - - action_len = actions.shape[-1] - action_log_prob = log_prob(action_prob, sequence) - action_log_prob = action_log_prob[:, -action_len:] - - actions = rearrange(actions, '1 ... -> ...') + time_cnt = 0 + for eps in tqdm(range(self.args.num_episodes), desc = 'episodes'): + for timestep in range(self.args.max_timesteps): + time_cnt += 1 + + # select a bunch of random states (prompts) + # and get the action (sampled sequence from rwkv as well as the action probs) + # also calculate the reward using reward model and store + # 随机挑选一条 prompt + rand_prompt_index = randrange(0, len(self.prompts)) + state = self.prompts[rand_prompt_index] + + # remove padding from state + state_mask = state != self.args.pad_value + state = state[state_mask] + + # get predicted sequence + # 与 environment 进行交互,其中返回的: + # action 是 response, + # sequence 是 prompt + response, + ( + actions, + sequence, + mask, + prompt_mask, + action_logits, + value + ) = self.actor_critic.generate( + rearrange(state, 'n -> 1 n'), + max_seq_len = self.args.ctx_len, + return_values = True + ) + action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token + + action_prob = action_logits.softmax(dim = -1) + + action_len = actions.shape[-1] + action_log_prob = log_prob(action_prob, sequence) + action_log_prob = action_log_prob[:, -action_len:] + + actions = rearrange(actions, '1 ... -> ...') + + # get reward as given by supervised trained reward model + sequence = torch.cat((state, actions), dim = 0) + + prompt_length = len(state) + prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length + + sequence = rearrange(sequence, 'n -> 1 n') + prompt_mask = rearrange(prompt_mask, 'n -> 1 n') + mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device) + + reward = self.reward_model( + sequence, + prompt_mask = prompt_mask, + mask = mask, + sample = True + ) + + self.sequence_batch.append(sequence) + self.prompt_mask_batch.append(prompt_mask) + self.mask_batch.append(mask) + self.action_prob_batch.append(action_prob) + self.action_log_prob_batch.append(action_log_prob) + self.reward_batch.append(reward) + self.value_batch.append(value) + + if time_cnt % self.args.update_timesteps == 0: + train_data = zip( + self.sequence_batch, self.prompt_mask_batch, self.mask_batch, + self.action_prob_batch, self.action_log_prob_batch, self.reward_batch, + self.value_batch + ) + + for _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value in train_data: + yield _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value + + self.sequence_batch.clear() + self.prompt_mask_batch.clear() + self.mask_batch.clear() + self.action_prob_batch.clear() + self.action_log_prob_batch.clear() + self.reward_batch.clear() + self.value_batch.clear() + + + def _dataloader(self) -> DataLoader: + ''' Initialize the Replay Buffer dataset used for retrieving experiences ''' + + dataset = ExperienceDataset(self.gen_experience_dataset) + dataloader = DataLoader(dataset=dataset, batch_size=self.args.micro_bsz) + return dataloader + + def train_dataloader(self) -> DataLoader: + ''' Get train loader ''' + + return self._dataloader() - # get reward as given by supervised trained reward model - sequence = torch.cat((state, actions), dim = 0) - prompt_length = len(state) - prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length - - sequence = rearrange(sequence, 'n -> 1 n') - prompt_mask = rearrange(prompt_mask, 'n -> 1 n') - mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device) - - reward = self.reward_model( - sequence, - prompt_mask = prompt_mask, - mask = mask, - sample = True - ) - - return ( - sequence, - prompt_mask, - mask, - action_prob, - action_log_prob, - reward, - value - ) diff --git a/train_ppo.py b/train_ppo.py index e8c8bae..2c69f47 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -138,6 +138,7 @@ if __name__ == "__main__": parser.add_argument("--num_episodes", default=50000, type=int) parser.add_argument("--max_timesteps", default=500, type=int) parser.add_argument("--update_timesteps", default=5000, type=int) + parser.add_argument("--eos_token", default=0, type=int) parser = Trainer.add_argparse_args(parser) @@ -249,7 +250,6 @@ if __name__ == "__main__": from collections import deque, namedtuple from einops import rearrange - from src.dataset import PPODataset, load_prompt_data_4_ppo from src.rlhf.ppo import RLHF from src.trainer import rlhf_train_callback from src.model import RWKV @@ -258,9 +258,6 @@ if __name__ == "__main__": # 用于 PPO 训练的数据,需要与 environment 交互获得 memory = [] - # 读入训练数据集 - prompts = load_prompt_data_4_ppo(args) - # 用 rwkv 初始化 actor 模型 actor = RWKV(args) actor.load(args.load_sft_model) @@ -298,21 +295,7 @@ if __name__ == "__main__": trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 - time_cnt = 0 - for eps in tqdm(range(args.num_episodes), desc = 'episodes'): - for timestep in range(args.max_timesteps): - time_cnt += 1 - - # 生成 ppo 模型的训练数据 - experience_data = rlhf_model.make_experience(prompts, eos_token=0) - memory.append(experience_data) - - # learn from the stored memories - if time_cnt % args.update_timesteps == 0: - train_data = PPODataset(memory) - data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) - - trainer.fit(rlhf_model, data_loader) + trainer.fit(rlhf_model) print('rlhf training complete') -- GitLab