diff --git a/src/rlhf/ppo_old.py b/src/rlhf/ppo_old.py
deleted file mode 100644
index 2a8fbda986140bdbfbf6d7b02f8d40da398256ab..0000000000000000000000000000000000000000
--- a/src/rlhf/ppo_old.py
+++ /dev/null
@@ -1,623 +0,0 @@
-import math
-from pathlib import Path
-import copy
-from tqdm import tqdm
-from functools import partial
-from collections import deque, namedtuple
-from random import randrange
-
-from beartype import beartype
-from beartype.typing import List, Optional, Callable, Deque
-
-import torch
-from torch import nn
-import torch.nn.functional as F
-
-from torch.optim import Adam
-from torch.utils.data import Dataset, DataLoader
-from torch.nn.utils.rnn import pad_sequence
-
-from pytorch_lightning.utilities import rank_zero_info
-
-from einops import rearrange, repeat
-from einops.layers.torch import Rearrange
-
-from src.model import RWKV
-from src.rlhf.reward import RewardModel
-from src.rlhf.optimizer import get_optimizer
-from src.rlhf.utils import masked_mean, eval_decorator
-
-from accelerate import Accelerator
-
-# actor critic - PaLM with lora
-
-PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
-    'actions',
-    'sequence',
-    'mask',
-    'prompt_mask',
-    'action_logits',
-    'values'
-])
-
-@beartype
-class ActorCritic(nn.Module):
-    def __init__(
-        self,
-        rwkv: RWKV,
-        critic_palm: Optional[RWKV] = None,
-        pooled_values = False,
-        actor_lora = True,
-        critic_lora = True,
-        actor_lora_r = 8,
-        critic_lora_r = 8,
-        actor_lora_scope = 'actor',
-        critic_lora_scope = 'critic',
-        actor_dropout = 0.,
-        critic_dropout = 0.
-    ):
-        super().__init__()
-        self.actor_palm = rwkv
-
-        self.critic_palm = critic_palm
-
-        if not exists(self.critic_palm):
-            self.critic_palm = copy.deepcopy(rwkv)
-
-        self.actor_palm.set_dropout(actor_dropout)
-        self.critic_palm.set_dropout(critic_dropout)
-
-        self.actor_lora = actor_lora
-        self.critic_lora = critic_lora
-
-        self.actor_lora_scope = actor_lora_scope if actor_lora else None
-        self.critic_lora_scope = critic_lora_scope if critic_lora else None
-
-        if self.actor_lora:
-            self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)
-
-        if self.critic_lora:
-            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
-
-        self.pooled_values = pooled_values
-        self.value_head = nn.Sequential(
-            nn.Linear(rwkv.dim, 1),
-            Rearrange('... 1 -> ...')
-        )
-
-        nn.init.zeros_(self.value_head[0].bias)
-        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
-
-    def actor_parameters(self):
-        if not self.actor_lora:
-            return self.actor_palm.parameters()
-
-        return [
-            *self.actor_palm.finetune_parameters(self.actor_lora_scope)
-        ]
-
-    def critic_parameters(self):
-        if not self.actor_lora:
-            return [*self.critic_palm.parameters(), *self.value_head.parameters()]
-
-        return [
-            *self.critic_palm.finetune_parameters(self.critic_lora_scope),
-            *self.value_head.parameters()
-        ]
-
-    @torch.no_grad()
-    @eval_decorator
-    def generate(
-        self,
-        state,
-        max_seq_len,
-        eos_token = None,
-        return_values = False,
-        **kwargs
-    ):
-        actions = self.actor_palm.generate(
-            max_seq_len,
-            prompt = state,
-            eos_token = eos_token,
-            finetune_scope = self.actor_lora_scope,
-            use_tqdm = True,
-            **kwargs
-        )
-
-        sequence = torch.cat((state, actions), dim = -1)
-        action_len = actions.shape[-1]
-        state_len = state.shape[-1]
-
-        prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
-        prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])
-
-        action_mask = ~prompt_mask
-
-        mask = None
-        if exists(eos_token):
-            mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
-            mask = F.pad(mask, (1, -1), value = True) # include eos token
-            action_mask &= mask
-
-        action_logits, value = self.forward(
-            sequence,
-            mask = action_mask,
-            return_values = return_values
-        )
-
-        return PPOActionCriticReturn(
-            actions,
-            sequence,
-            mask,
-            prompt_mask,
-            action_logits,
-            value
-        )
-
-    def forward(
-        self,
-        x,
-        mask = None,
-        return_values = True
-    ):
-        action_logits = self.actor_palm(
-            x,
-            finetune_scope = self.actor_lora_scope
-        )
-
-        if not return_values:
-            return action_logits, None
-
-        critic_embeds = self.critic_palm(
-            x,
-            return_only_embedding = True,
-            finetune_scope = self.critic_lora_scope
-        )
-
-        if self.pooled_values:
-            critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
-            critic_embeds = masked_mean(critic_embeds, mask, dim = 1)
-
-        values = self.value_head(critic_embeds)
-
-        return action_logits, values
-
-# data
-
-Memory = namedtuple('Memory', [
-    'sequence',
-    'prompt_mask',
-    'mask',
-    'action_prob',
-    'action_log_prob',
-    'reward',
-    'value'
-])
-
-@beartype
-class ExperienceDataset(Dataset):
-    def __init__(
-        self,
-        data: List[torch.Tensor],
-        device = None
-    ):
-        super().__init__()
-        self.data = data
-        self.device = device
-
-    def __len__(self):
-        return self.data[0].shape[0]
-
-    def __getitem__(self, ind):
-        return tuple(map(lambda t: t[ind].to(self.device), self.data))
-
-def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
-    ds = ExperienceDataset(data, device = device)
-    return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)
-
-# helper functions
-
-def exists(val):
-    return val is not None
-
-def default(val, d):
-    return val if exists(val) else d
-
-def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
-    dim = default(dim, tuple(range(t.ndim)))
-    kwargs = dict(dim = dim, keepdim = True)
-
-    mean = masked_mean(t, mask = mask, **kwargs)
-    mean_centered = t - mean
-    var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)
-
-    return mean_centered * var.clamp(min = eps).rsqrt()
-
-def pad_sequence_fixed(sequences, *args, **kwargs):
-    first_el = sequences[0]
-    has_no_dimension = first_el.ndim == 0
-
-    # if no dimensions, add a single dimension
-    if has_no_dimension:
-        sequences = tuple(map(lambda t: t[None], sequences))
-
-    out = pad_sequence(sequences, *args, **kwargs)
-
-    if has_no_dimension:
-        out = rearrange(out, '... 1 -> ...')
-
-    return out
-
-def log(t, eps = 1e-20):
-    return torch.log(t.clamp(min = eps))
-
-def log_prob(prob, indices):
-    assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
-    return log(prob.gather(-1, indices[..., None])).squeeze(-1)
-
-def shift(t, value = 0, shift = 1, dim = -1):
-    zeros = (0, 0) * (-dim - 1)
-    return F.pad(t, (*zeros, shift, -shift), value = value)
-
-def masked_entropy(prob, dim = -1, mask = None):
-    entropies = (prob * log(prob)).sum(dim = -1)
-    return masked_mean(entropies, mask = mask).mean()
-
-def masked_kl_div(prob1, prob2, mask = None):
-    """
-    need to account for variable sequence lengths, therefore not using the built-in functional version
-    """
-    kl_divs = (prob1 * (log(prob2) - log(prob1))).sum(dim = -1)
-
-    if not exists(mask):
-        return kl_divs.mean()
-
-    return masked_mean(kl_divs, mask).mean()
-
-def clipped_value_loss(values, rewards, old_values, clip):
-    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
-    value_loss_1 = (value_clipped.flatten() - rewards) ** 2
-    value_loss_2 = (values.flatten() - rewards) ** 2
-    return torch.mean(torch.max(value_loss_1, value_loss_2))
-
-# rlhf trainer
-
-@beartype
-class RLHFTrainer(nn.Module):
-    def __init__(
-        self,
-        args,
-        accelerate_kwargs: dict = {}
-    ):
-        super().__init__()
-
-        self.args = args
-
-        self.accelerate = Accelerator(**accelerate_kwargs)
-
-        # load the RWKV model
-        rwkv = RWKV(args)
-
-        if len(args.load_sft_model) == 0:
-            rank_zero_info("an SFT model must be loaded, please set load_sft_model")
-            exit(1)
-
-        rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
-        try:
-            load_dict = torch.load(args.load_sft_model, map_location="cpu")
-        except:
-            rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
-            exit(1)
-
-        if args.load_partial == 1:
-            load_keys = load_dict.keys()
-            for k in rwkv.state_dict():
-                if k not in load_keys:
-                    load_dict[k] = rwkv.state_dict()[k]
-        rwkv.load_state_dict(load_dict)
-
-        self.rwkv = rwkv
-
-        # initialize the actor-critic from the RWKV model
-        actor_critic = ActorCritic(
-            rwkv = self.rwkv,
-            actor_lora = args.actor_lora,
-            critic_lora = args.critic_lora,
-            actor_lora_r = args.actor_lora_r,
-            critic_lora_r = args.critic_lora_r,
-            pooled_values = args.critic_pooled_values,
-            actor_dropout = args.actor_dropout,
-            critic_dropout = args.critic_dropout
-        ).to(self.rwkv.device)
-
-        self.actor_critic = actor_critic
-
-        # load the reward model and set it to evaluation mode
-        reward_model = RewardModel(args)
-        reward_model.load(args.load_rm_model)
-        self.reward_model = reward_model.eval()
-
-        # optimizers
-        self.actor_optim = get_optimizer(actor_critic.actor_parameters(), lr = self.args.actor_lr, wd = self.args.actor_wd, betas = self.args.betas, eps = self.args.actor_adam_eps, use_lion = self.args.use_lion)
-        self.critic_optim = get_optimizer(actor_critic.critic_parameters(), lr = self.args.critic_lr, wd = self.args.critic_wd, betas = self.args.betas, eps = self.args.critic_adam_eps, use_lion = self.args.use_lion)
-
-        # prepare with accelerator
-
-        (
-            self.actor_critic,
-            self.reward_model,
-            self.actor_optim,
-            self.critic_optim
-        ) = self.accelerate.prepare(
-            self.actor_critic,
-            self.reward_model,
-            self.actor_optim,
-            self.critic_optim
-        )
-
-
-    def print(self, msg):
-        return self.accelerate.print(msg)
-
-    def save(self, filepath = './checkpoint.pt'):
-        torch.save(self.actor_critic.state_dict(), filepath)
-
-    def load(self, filepath = './checkpoint.pt'):
-        state_dict = torch.load(filepath)
-        self.actor_critic.load_state_dict(state_dict)
-
-    @property
-    def device(self):
-        return self.accelerate.device
-
-    @torch.no_grad()
-    def generate(
-        self,
-        max_seq_len,
-        *args,
-        prompt,
-        num_samples = 4, # sample 4 per prompt and select the one with highest reward
-        **kwargs
-    ):
-        assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
-        prompt = repeat(prompt, 'n -> b n', b = num_samples)
-
-        actor_critic = self.accelerate.unwrap_model(self.actor_critic)
-        reward_model = self.accelerate.unwrap_model(self.reward_model)
-
-        actor_critic.eval()
-
-        (
-            actions,
-            sequences,
-            mask,
-            prompt_mask,
-            action_logits,
-            _
-        ) = actor_critic.generate(
-            prompt,
-            *args,
-            max_seq_len = max_seq_len,
-            return_values = False,
-            **kwargs
-        )
-
-        rewards = reward_model(
-            sequences,
-            prompt_mask = prompt_mask,
-            mask = mask,
-            sample = True
-        )
-
-        best_sequence_index = rewards.topk(1, dim = -1).indices
-
-        best_sequence = sequences[best_sequence_index]
-        best_sequence = rearrange(best_sequence, '1 ... -> ...')
-
-        return best_sequence
-
-    def learn(
-        self,
-        memories: Deque[Memory]
-    ):
-        # stack all data stored in the memories
-
-        all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories)))
-
-        # prepare dataloader for policy phase training
-
-        dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device)
-
-        self.actor_critic.train()
-
-        # PPO training
-
-        for _ in range(self.epochs):
-            for (
-                sequences,
-                prompt_masks,
-                masks,
-                old_action_probs,
-                old_log_probs,
-                rewards,
-                old_values
-            ) in dl:
-                action_masks = ~prompt_masks & masks
-
-                action_logits, values = self.actor_critic(
-                    sequences,
-                    mask = action_masks
-                )
-
-                action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
-                action_len = old_log_probs.shape[-1]
-
-                action_probs = action_logits.softmax(dim = -1)
-                action_log_probs = log_prob(action_probs, sequences)
-                action_log_probs = action_log_probs[:, -action_len:]
-
-                # calculate entropies, taking into account which part of the sequence is actually an action
-
-                entropies = masked_entropy(action_probs, mask = action_masks)
-
-                # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
-
-                kl_div_loss = 0.
-
-                if self.args.kl_div_loss_weight > 0:
-                    kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight
-
-                # handle non-pooled values
-
-                normalize_kwargs = dict()
-
-                if old_values.ndim == 2:
-                    old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))
-
-                    old_values = old_values[:, -action_len:]
-                    values = values[:, -action_len:]
-                    rewards = rearrange(rewards, 'b -> b 1')
-                    normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])
-
-                if values.ndim < rewards.ndim:
-                    values = rearrange(values, '... -> ... 1')
-
-                # calculate clipped surrogate objective, classic PPO loss
-
-                ratios = (action_log_probs - old_log_probs).exp()
-                advantages = masked_normalize(rewards - old_values, **normalize_kwargs)
-
-                if advantages.ndim == 1:
-                    advantages = rearrange(advantages, 'b -> b 1')
-
-                surr1 = ratios * advantages
-                surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
-                policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies
-
-                # combine losses
-
-                loss = policy_loss.mean() + kl_div_loss
-
-                # update actor
-
-                self.accelerate.backward(loss)
-
-                self.print(f'policy_loss: {loss.item():.3f}')
-
-                if exists(self.args.max_norm):
-                    self.accelerate.clip_grad_norm_(self.actor_critic.actor_parameters(), self.args.max_norm)
-
-                self.actor_optim.step()
-                self.actor_optim.zero_grad()
-
-                # calculate value loss and update value network separate from policy network
-
-                value_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
-                value_loss = value_loss.mean()
-
-                self.print(f'critic_loss: {value_loss.item():.3f}')
-
-                self.accelerate.backward(value_loss)
-
-                if exists(self.args.max_norm):
-                    self.accelerate.clip_grad_norm_(self.actor_critic.critic_parameters(), self.args.max_norm)
-
-                self.critic_optim.step()
-                self.critic_optim.zero_grad()
-
-    def train(
-        self,
-        num_episodes = 50000,
-        max_timesteps = 500,
-        update_timesteps = 5000,
-        max_batch_size = 16,
-        eos_token = None,
-        temperature = 1.
-    ):
-        device = self.device
-
-        time = 0
-        memories = deque([])
-
-        for eps in tqdm(range(num_episodes), desc = 'episodes'):
-            for timestep in range(max_timesteps):
-                time += 1
-
-                # select a bunch of random states (prompts)
-                # and get the action (sampled sequence from palm as well as the action probs)
-                # also calculate the reward using reward model and store
-
-                rand_prompt_index = randrange(0, self.num_prompts)
-
-                state = self.prompt_token_ids[rand_prompt_index]
-
-                # remove padding from state
-
-                state_mask = state != self.args.pad_value
-                state = state[state_mask]
-
-                # get predicted sequence
-
-                (
-                    actions,
-                    sequence,
-                    mask,
-                    prompt_mask,
-                    action_logits,
-                    value
-                ) = self.actor_critic.generate(
-                    rearrange(state, 'n -> 1 n'),
-                    max_seq_len = self.args.ctx_len,
-                    eos_token = eos_token,
-                    temperature = temperature,
-                    return_values = True
-                )
-                action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
-
-                action_prob = action_logits.softmax(dim = -1)
-
-                action_len = actions.shape[-1]
-                action_log_prob = log_prob(action_prob, sequence)
-                action_log_prob = action_log_prob[:, -action_len:]
-
-                actions = rearrange(actions, '1 ... -> ...')
-
-                # get reward as given by supervised trained reward model
-
-                sequence = torch.cat((state, actions), dim = 0)
-
-                prompt_length = len(state)
-                prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length
-
-                sequence = rearrange(sequence, 'n -> 1 n')
-                prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
-                mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)
-
-                reward = self.reward_model(
-                    sequence,
-                    prompt_mask = prompt_mask,
-                    mask = mask,
-                    sample = True
-                )
-
-                detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')
-
-                # store memory for learning
-
-                memories.append(Memory(*map(detach_to_cpu_, (
-                    sequence,
-                    prompt_mask,
-                    mask,
-                    action_prob,
-                    action_log_prob,
-                    reward,
-                    value
-                ))))
-
-                # learn from the stored memories
-
-                if time % update_timesteps == 0:
-                    self.learn(memories)
-                    memories.clear()
-
-        print('rlhf training complete')