import math
from pathlib import Path
import copy
from tqdm import tqdm
from functools import partial
from collections import deque, namedtuple
from random import randrange

from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from pytorch_lightning.utilities import rank_zero_info

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.optimizer import get_optimizer
from src.rlhf.utils import masked_mean, eval_decorator

from accelerate import Accelerator

# actor critic - RWKV with LoRA

PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
    'actions',
    'sequence',
    'mask',
    'prompt_mask',
    'action_logits',
    'values'
])

@beartype
class ActorCritic(nn.Module):
    def __init__(
        self,
        rwkv: RWKV,
        critic_palm: Optional[RWKV] = None,
        pooled_values = False,
        actor_lora = True,
        critic_lora = True,
        actor_lora_r = 8,
        critic_lora_r = 8,
        actor_lora_scope = 'actor',
        critic_lora_scope = 'critic',
        actor_dropout = 0.,
        critic_dropout = 0.
    ):
        super().__init__()
        self.actor_palm = rwkv

        self.critic_palm = critic_palm

        if not exists(self.critic_palm):
            self.critic_palm = copy.deepcopy(rwkv)

        self.actor_palm.set_dropout(actor_dropout)
        self.critic_palm.set_dropout(critic_dropout)

        self.actor_lora = actor_lora
        self.critic_lora = critic_lora

        self.actor_lora_scope = actor_lora_scope if actor_lora else None
        self.critic_lora_scope = critic_lora_scope if critic_lora else None

        if self.actor_lora:
            self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)

        if self.critic_lora:
            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)

        self.pooled_values = pooled_values
        self.value_head = nn.Sequential(
            nn.Linear(rwkv.dim, 1),
            Rearrange('... 1 -> ...')
        )

        nn.init.zeros_(self.value_head[0].bias)
        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))

    def actor_parameters(self):
        if not self.actor_lora:
            return self.actor_palm.parameters()

        return [
            *self.actor_palm.finetune_parameters(self.actor_lora_scope)
        ]

    def critic_parameters(self):
        if not self.critic_lora:
            return [*self.critic_palm.parameters(), *self.value_head.parameters()]

        return [
            *self.critic_palm.finetune_parameters(self.critic_lora_scope),
            *self.value_head.parameters()
        ]

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        state,
        max_seq_len,
        eos_token = None,
        return_values = False,
        **kwargs
    ):
        # sample a continuation of the prompt (state) from the actor
        actions = self.actor_palm.generate(
            max_seq_len,
            prompt = state,
            eos_token = eos_token,
            finetune_scope = self.actor_lora_scope,
            use_tqdm = True,
            **kwargs
        )

        sequence = torch.cat((state, actions), dim = -1)
        action_len = actions.shape[-1]
        state_len = state.shape[-1]

        prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
        prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])

        action_mask = ~prompt_mask

        mask = None
        if exists(eos_token):
            mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
            mask = F.pad(mask, (1, -1), value = True) # include eos token
            action_mask &= mask

        action_logits, value = self.forward(
            sequence,
            mask = action_mask,
            return_values = return_values
        )

        return PPOActionCriticReturn(
            actions,
            sequence,
            mask,
            prompt_mask,
            action_logits,
            value
        )

    def forward(
        self,
        x,
        mask = None,
        return_values = True
    ):
        action_logits = self.actor_palm(
            x,
            finetune_scope = self.actor_lora_scope
        )

        if not return_values:
            return action_logits, None

        critic_embeds = self.critic_palm(
            x,
            return_only_embedding = True,
            finetune_scope = self.critic_lora_scope
        )

        if self.pooled_values:
            critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
            critic_embeds = masked_mean(critic_embeds, mask, dim = 1)

        values = self.value_head(critic_embeds)

        return action_logits, values
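
# A minimal usage sketch for the actor-critic above (illustrative only, not part of the
# original module): `rwkv_model` and `prompt_ids` are placeholder names for a configured
# RWKV instance and a (1, n) tensor of prompt token ids.
#
#   actor_critic = ActorCritic(rwkv = rwkv_model, actor_lora = True, critic_lora = True)
#   rollout = actor_critic.generate(prompt_ids, max_seq_len = 256, return_values = True)
#   # rollout.actions  -> sampled response tokens (prompt excluded)
#   # rollout.sequence -> prompt + response tokens
#   # rollout.values   -> critic value estimates consumed later by PPO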

# data

Memory = namedtuple('Memory', [
    'sequence',
    'prompt_mask',
    'mask',
    'action_prob',
    'action_log_prob',
    'reward',
    'value'
])

@beartype
class ExperienceDataset(Dataset):
    def __init__(
        self,
        data: List[torch.Tensor],
        device = None
    ):
        super().__init__()
        self.data = data
        self.device = device

    def __len__(self):
        return self.data[0].shape[0]

    def __getitem__(self, ind):
        return tuple(map(lambda t: t[ind].to(self.device), self.data))

def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
    ds = ExperienceDataset(data, device = device)
    return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
    dim = default(dim, tuple(range(t.ndim)))
    kwargs = dict(dim = dim, keepdim = True)

    mean = masked_mean(t, mask = mask, **kwargs)
    mean_centered = t - mean
    var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)

    return mean_centered * var.clamp(min = eps).rsqrt()
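
# A sketch of how masked_normalize is used by the PPO update below (illustrative only,
# placeholder tensor names): advantages are whitened per sequence while ignoring
# prompt / padded positions.
#
#   adv = rewards - old_values                                  # (b, action_len)
#   adv = masked_normalize(adv, dim = -1, mask = action_masks)  # roughly zero-mean, unit-variance per row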

def pad_sequence_fixed(sequences, *args, **kwargs):
    first_el = sequences[0]
    has_no_dimension = first_el.ndim == 0

    # if no dimensions, add a single dimension
    if has_no_dimension:
        sequences = tuple(map(lambda t: t[None], sequences))

    out = pad_sequence(sequences, *args, **kwargs)

    if has_no_dimension:
        out = rearrange(out, '... 1 -> ...')

    return out

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def log_prob(prob, indices):
    assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
    return log(prob.gather(-1, indices[..., None])).squeeze(-1)

def shift(t, value = 0, shift = 1, dim = -1):
    zeros = (0, 0) * (-dim - 1)
    return F.pad(t, (*zeros, shift, -shift), value = value)

def masked_entropy(prob, dim = -1, mask = None):
    entropies = (prob * log(prob)).sum(dim = -1)
    return masked_mean(entropies, mask = mask).mean()

def masked_kl_div(prob1, prob2, mask = None):
    """
    need to account for variable sequence lengths, therefore not using the built-in functional version
    """
    kl_divs = (prob1 * (log(prob2) - log(prob1))).sum(dim = -1)

    if not exists(mask):
        return kl_divs.mean()

    return masked_mean(kl_divs, mask).mean()

def clipped_value_loss(values, rewards, old_values, clip):
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    value_loss_1 = (value_clipped.flatten() - rewards) ** 2
    value_loss_2 = (values.flatten() - rewards) ** 2
    return torch.mean(torch.max(value_loss_1, value_loss_2))
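
# Worked example for the clipped value loss above (illustrative numbers only):
# with old_values = 0.0, values = 1.0, rewards = 2.0 and clip = 0.5,
#   value_clipped = 0.0 + clamp(1.0 - 0.0, -0.5, 0.5) = 0.5
#   loss_clipped  = (0.5 - 2.0) ** 2 = 2.25
#   loss_raw      = (1.0 - 2.0) ** 2 = 1.00
# the element-wise max keeps the larger term (2.25), so the critic gets no gradient
# benefit from moving its estimate further than `clip` away from the old value in a
# single update.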

# rlhf trainer

@beartype
class RLHFTrainer(nn.Module):
    def __init__(
        self,
        args,
        accelerate_kwargs: dict = {}
    ):
        super().__init__()

        self.args = args
        self.accelerate = Accelerator(**accelerate_kwargs)

        # load the RWKV model
        rwkv = RWKV(args)

        if len(args.load_sft_model) == 0:
            rank_zero_info("RLHF training requires an SFT model checkpoint, please set args.load_sft_model")
            exit(1)

        rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
        try:
            load_dict = torch.load(args.load_sft_model, map_location="cpu")
        except Exception:
            rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
            exit(1)

        if args.load_partial == 1:
            load_keys = load_dict.keys()
            for k in rwkv.state_dict():
                if k not in load_keys:
                    load_dict[k] = rwkv.state_dict()[k]

        rwkv.load_state_dict(load_dict)
        self.rwkv = rwkv

        # initialize the actor-critic from the loaded RWKV model
        actor_critic = ActorCritic(
            rwkv = self.rwkv,
            actor_lora = args.actor_lora,
            critic_lora = args.critic_lora,
            actor_lora_r = args.actor_lora_r,
            critic_lora_r = args.critic_lora_r,
            pooled_values = args.critic_pooled_values,
            actor_dropout = args.actor_dropout,
            critic_dropout = args.critic_dropout
        ).to(self.rwkv.device)

        self.actor_critic = actor_critic

        # load the reward model and keep it in evaluation mode
        reward_model = RewardModel(args)
        reward_model.load(args.load_rm_model)
        self.reward_model = reward_model.eval()

        # optimizers
        self.actor_optim = get_optimizer(actor_critic.actor_parameters(), lr = self.args.actor_lr, wd = self.args.actor_wd, betas = self.args.betas, eps = self.args.actor_adam_eps, use_lion = self.args.use_lion)
        self.critic_optim = get_optimizer(actor_critic.critic_parameters(), lr = self.args.critic_lr, wd = self.args.critic_wd, betas = self.args.betas, eps = self.args.critic_adam_eps, use_lion = self.args.use_lion)

        # prepare with accelerator
        (
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        ) = self.accelerate.prepare(
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        )

    def print(self, msg):
        return self.accelerate.print(msg)

    def save(self, filepath = './checkpoint.pt'):
        torch.save(self.actor_critic.state_dict(), filepath)

    def load(self, filepath = './checkpoint.pt'):
        state_dict = torch.load(filepath)
        self.actor_critic.load_state_dict(state_dict)

    @property
    def device(self):
        return self.accelerate.device

    @torch.no_grad()
    def generate(
        self,
        max_seq_len,
        *args,
        prompt,
        num_samples = 4, # sample 4 per prompt and select the one with highest reward
        **kwargs
    ):
        assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
        prompt = repeat(prompt, 'n -> b n', b = num_samples)

        actor_critic = self.accelerate.unwrap_model(self.actor_critic)
        reward_model = self.accelerate.unwrap_model(self.reward_model)

        actor_critic.eval()

        (
            actions,
            sequences,
            mask,
            prompt_mask,
            action_logits,
            _
        ) = actor_critic.generate(
            prompt,
            *args,
            max_seq_len = max_seq_len,
            return_values = False,
            **kwargs
        )

        rewards = reward_model(
            sequences,
            prompt_mask = prompt_mask,
            mask = mask,
            sample = True
        )

        best_sequence_index = rewards.topk(1, dim = -1).indices

        best_sequence = sequences[best_sequence_index]
        best_sequence = rearrange(best_sequence, '1 ... -> ...')

        return best_sequence
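
    # A best-of-n sampling sketch for the method above (illustrative only; `trainer` is a
    # constructed RLHFTrainer and `prompt_ids` a 1-d tensor of prompt token ids):
    #
    #   best = trainer.generate(trainer.args.ctx_len, prompt = prompt_ids, num_samples = 4)
    #
    # four continuations are sampled, scored by the reward model, and the highest-reward
    # sequence is returned.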

    def learn(
        self,
        memories: Deque[Memory]
    ):
        # stack all data stored in the memories
        all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories)))

        # prepare dataloader for policy phase training
        # minibatch_size and epochs are assumed to be provided on args alongside the other PPO hyperparameters
        dl = create_dataloader(all_memories_stacked_and_padded, self.args.minibatch_size, device = self.device)

        self.actor_critic.train()

        # PPO training
        for _ in range(self.args.epochs):
            for (
                sequences,
                prompt_masks,
                masks,
                old_action_probs,
                old_log_probs,
                rewards,
                old_values
            ) in dl:
                action_masks = ~prompt_masks & masks

                action_logits, values = self.actor_critic(
                    sequences,
                    mask = action_masks
                )

                action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token

                action_len = old_log_probs.shape[-1]
                action_probs = action_logits.softmax(dim = -1)
                action_log_probs = log_prob(action_probs, sequences)
                action_log_probs = action_log_probs[:, -action_len:]

                # calculate entropies, taking into account which part of the sequence is actually an action
                entropies = masked_entropy(action_probs, mask = action_masks)

                # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
                kl_div_loss = 0.

                if self.args.kl_div_loss_weight > 0:
                    kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight

                # handle non-pooled values
                normalize_kwargs = dict()

                if old_values.ndim == 2:
                    old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))

                    old_values = old_values[:, -action_len:]
                    values = values[:, -action_len:]
                    rewards = rearrange(rewards, 'b -> b 1')
                    normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])

                if values.ndim < rewards.ndim:
                    values = rearrange(values, '... -> ... 1')

                # calculate clipped surrogate objective, classic PPO loss
                ratios = (action_log_probs - old_log_probs).exp()
                advantages = masked_normalize(rewards - old_values, **normalize_kwargs)

                if advantages.ndim == 1:
                    advantages = rearrange(advantages, 'b -> b 1')

                surr1 = ratios * advantages
                surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
                policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies

                # combine losses
                loss = policy_loss.mean() + kl_div_loss

                # update actor
                self.accelerate.backward(loss)

                self.print(f'policy_loss: {loss.item():.3f}')

                if exists(self.args.max_norm):
                    self.accelerate.clip_grad_norm_(self.actor_critic.actor_parameters(), self.args.max_norm)

                self.actor_optim.step()
                self.actor_optim.zero_grad()

                # calculate value loss and update value network separate from policy network
                value_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
                value_loss = value_loss.mean()

                self.print(f'critic_loss: {value_loss.item():.3f}')

                self.accelerate.backward(value_loss)

                if exists(self.args.max_norm):
                    self.accelerate.clip_grad_norm_(self.actor_critic.critic_parameters(), self.args.max_norm)

                self.critic_optim.step()
                self.critic_optim.zero_grad()

    def train(
        self,
        num_episodes = 50000,
        max_timesteps = 500,
        update_timesteps = 5000,
        max_batch_size = 16,
        eos_token = None,
        temperature = 1.
    ):
        device = self.device

        time = 0
        memories = deque([])

        for eps in tqdm(range(num_episodes), desc = 'episodes'):
            for timestep in range(max_timesteps):
                time += 1

                # select a random state (prompt),
                # get the action (sampled sequence from the RWKV actor as well as the action probs),
                # then calculate the reward using the reward model and store the transition
                # NOTE: prompt_token_ids / num_prompts are assumed to be attached to the trainer
                # (e.g. a padded tensor of tokenized prompts) before train() is called

                rand_prompt_index = randrange(0, self.num_prompts)

                state = self.prompt_token_ids[rand_prompt_index]

                # remove padding from state
                state_mask = state != self.args.pad_value
                state = state[state_mask]

                # get predicted sequence
                (
                    actions,
                    sequence,
                    mask,
                    prompt_mask,
                    action_logits,
                    value
                ) = self.actor_critic.generate(
                    rearrange(state, 'n -> 1 n'),
                    max_seq_len = self.args.ctx_len,
                    eos_token = eos_token,
                    temperature = temperature,
                    return_values = True
                )

                action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token

                action_prob = action_logits.softmax(dim = -1)

                action_len = actions.shape[-1]
                action_log_prob = log_prob(action_prob, sequence)
                action_log_prob = action_log_prob[:, -action_len:]

                actions = rearrange(actions, '1 ... -> ...')

                # get reward as given by supervised trained reward model
                sequence = torch.cat((state, actions), dim = 0)

                prompt_length = len(state)
                prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length

                sequence = rearrange(sequence, 'n -> 1 n')
                prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
                mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)

                reward = self.reward_model(
                    sequence,
                    prompt_mask = prompt_mask,
                    mask = mask,
                    sample = True
                )

                detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')

                # store memory for learning
                memories.append(Memory(*map(detach_to_cpu_, (
                    sequence,
                    prompt_mask,
                    mask,
                    action_prob,
                    action_log_prob,
                    reward,
                    value
                ))))

                # learn from the stored memories
                if time % update_timesteps == 0:
                    self.learn(memories)
                    memories.clear()

        print('rlhf training complete')
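
# A minimal end-to-end sketch (illustrative only; the attribute names mirror what the
# trainer reads above, but the exact args object comes from the project's training script):
#
#   trainer = RLHFTrainer(args)
#   trainer.prompt_token_ids = prompt_token_ids   # (num_prompts, prompt_len), padded with args.pad_value
#   trainer.num_prompts = prompt_token_ids.shape[0]
#   trainer.train(num_episodes = 10, max_timesteps = 10, update_timesteps = 100)
#   trainer.save('./rlhf_checkpoint.pt')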