"""PPO-style actor-critic training (RLHF) for an RWKV language model.

Contains the actor/critic wrapper, an experience dataset, a set of masked
tensor helpers and the ``RLHF`` LightningModule that ties them together with
a reward model.
"""

import math
from pathlib import Path
import copy
from tqdm import tqdm
from functools import partial
from collections import deque, namedtuple
from random import randrange

from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

import torch
from torch import nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.strategies import DeepSpeedStrategy
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.optimizer import get_optimizer
from src.rlhf.utils import masked_mean, eval_decorator

# actor critic

# Bundle returned by ``ActorCritic.generate``.
PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
    'actions',
    'sequence',
    'mask',
    'prompt_mask',
    'action_logits',
    'values'
])


@beartype
class ActorCritic(nn.Module):
    """Actor (policy) and critic (value) pair, both built from an RWKV model.

    The actor is always a deep copy of ``rwkv``. The critic is the supplied
    ``critic`` model, or another deep copy of ``rwkv`` when none is given.
    A linear value head maps critic embeddings of width ``args.n_embd`` to a
    scalar.
    """

    def __init__(
        self,
        rwkv: RWKV,
        args,
        critic: Optional[RWKV] = None,
        pooled_values = False
    ):
        """
        Args:
            rwkv: base model that is deep-copied into the actor (and into the
                critic when ``critic`` is None).
            args: configuration object; only ``args.n_embd`` is read here.
            critic: optional separate critic model.
            pooled_values: when True, ``forward`` mean-pools the critic
                embeddings over the masked positions before the value head.
        """
        super().__init__()

        self.actor = copy.deepcopy(rwkv)

        self.critic = critic
        if not exists(self.critic):
            self.critic = copy.deepcopy(rwkv)

        self.pooled_values = pooled_values

        # Projects each embedding to one scalar and drops the trailing axis.
        self.value_head = nn.Sequential(
            nn.Linear(args.n_embd, 1),
            Rearrange('... 1 -> ...')
        )

        nn.init.zeros_(self.value_head[0].bias)
        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        state,
        max_seq_len,
        eos_token = None,
        return_values = False,
        **kwargs
    ):
        """Sample a response for ``state`` and score it with actor and critic.

        Args:
            state: prompt token ids; concatenated with the generated actions
                along the last dimension (assumed (batch, n) — TODO confirm
                against ``RWKV.generate``).
            max_seq_len: forwarded to ``self.actor.generate``.
            eos_token: optional end-of-sequence token id; when given, positions
                after the first eos token are masked out of the action mask.
            return_values: whether to also run the critic.
            **kwargs: forwarded to ``self.actor.generate``.

        Returns:
            PPOActionCriticReturn with ``mask`` set to None when ``eos_token``
            is None, and ``values`` set to None when ``return_values`` is False.
        """
        # Produce one response; this corresponds to taking one action.
        actions = self.actor.generate(
            max_seq_len,
            prompt = state,
            eos_token = eos_token,
            use_tqdm = True,
            **kwargs
        )

        # Concatenate the prompt (state) and the response (action).
        sequence = torch.cat((state, actions), dim = -1)
        state_len = state.shape[-1]

        # Build the prompt mask (state mask) and the response mask (action mask).
        prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
        prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])

        action_mask = ~prompt_mask

        # Account for the eos token.
        mask = None
        if exists(eos_token):
            # True up to (not including) the first eos token ...
            mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
            # ... then shifted right by one so the eos token itself is included.
            mask = F.pad(mask, (1, -1), value = True)
            action_mask &= mask

        # Feed the generated sequence to the actor to obtain action_logits and
        # to the critic to obtain the value.
        action_logits, value = self.forward(
            sequence,
            mask = action_mask,
            return_values = return_values
        )

        return PPOActionCriticReturn(
            actions,
            sequence,
            mask,
            prompt_mask,
            action_logits,
            value
        )

    def forward(
        self,
        x,
        mask = None,
        return_values = True
    ):
        """Run the actor, and optionally the critic plus value head, on ``x``.

        Args:
            x: token ids passed unchanged to actor and critic.
            mask: used only when ``self.pooled_values`` is True, as the mask
                for mean-pooling the critic embeddings along dim 1.
            return_values: when False the critic is skipped.

        Returns:
            ``(action_logits, values)``; ``values`` is None when
            ``return_values`` is False.
        """
        action_logits, _ = self.actor(
            x,
            ppo_train = True
        )

        if not return_values:
            return action_logits, None

        _, critic_embeds = self.critic(
            x,
            return_only_embedding = True,
            ppo_train = True
        )

        if self.pooled_values:
            critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
            critic_embeds = masked_mean(critic_embeds, mask, dim = 1)

        values = self.value_head(critic_embeds)

        return action_logits, values


@beartype
class ExperienceDataset(Dataset):
    """Dataset over parallel tensors that share their first dimension.

    Item ``i`` is the tuple of the ``i``-th row of every tensor in ``data``,
    each moved to ``device``.
    """

    def __init__(
        self,
        data: List[torch.Tensor],
        device = None
    ):
        super().__init__()
        self.data = data
        self.device = device

    def __len__(self):
        # The length is taken from the first tensor only.
        return self.data[0].shape[0]

    def __getitem__(self, ind):
        return tuple(map(lambda t: t[ind].to(self.device), self.data))


def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
    """Wrap ``data`` in an ``ExperienceDataset`` and return a ``DataLoader``.

    Extra keyword arguments are forwarded to ``DataLoader``.
    """
    ds = ExperienceDataset(data, device = device)
    return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)

# helper functions

def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None


def default(val, d):
    """Return ``val`` unless it is None, in which case return ``d``."""
    return val if exists(val) else d


def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
    """Center ``t`` by its masked mean and scale by the inverse masked std.

    Args:
        t: tensor to normalize.
        eps: lower clamp applied to the variance before ``rsqrt``.
        mask: forwarded to ``masked_mean``.
        dim: dimension(s) to reduce over; defaults to all dimensions of ``t``.
    """
    dim = default(dim, tuple(range(t.ndim)))
    kwargs = dict(dim = dim, keepdim = True)

    mean = masked_mean(t, mask = mask, **kwargs)
    mean_centered = t - mean
    var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)

    return mean_centered * var.clamp(min = eps).rsqrt()


def pad_sequence_fixed(sequences, *args, **kwargs):
    """``pad_sequence`` that also accepts 0-dimensional tensors.

    When the first element has no dimensions, every element gets a leading
    axis before padding and the trailing singleton axis is removed afterwards.
    """
    first_el = sequences[0]
    has_no_dimension = first_el.ndim == 0

    # if no dimensions, add a single dimension
    if has_no_dimension:
        sequences = tuple(map(lambda t: t[None], sequences))

    out = pad_sequence(sequences, *args, **kwargs)

    if has_no_dimension:
        out = rearrange(out, '... 1 -> ...')

    return out


def log(t, eps = 1e-20):
    """Logarithm of ``t`` after clamping it to at least ``eps``."""
    return torch.log(t.clamp(min = eps))


def log_prob(prob, indices):
    """Gather the log-probability of ``indices`` from ``prob``.

    ``prob.shape[:2]`` must equal ``indices.shape``; the lookup is along the
    last dimension of ``prob``.
    """
    assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
    return log(prob.gather(-1, indices[..., None])).squeeze(-1)


def shift(t, value = 0, shift = 1, dim = -1):
    """Shift ``t`` by ``shift`` positions along ``dim``.

    ``shift`` entries filled with ``value`` are inserted at the start of the
    dimension and the same number are cropped from its end, so the shape is
    unchanged.
    """
    # F.pad addresses dimensions from the last one backwards, so the offset is
    # computed from a negative index; normalize non-negative ``dim`` first
    # (previously a non-negative ``dim`` silently shifted the last dimension).
    if dim >= 0:
        dim -= t.ndim

    zeros = (0, 0) * (-dim - 1)
    return F.pad(t, (*zeros, shift, -shift), value = value)


def masked_entropy(prob, dim = -1, mask = None):
    """Masked mean of ``sum(prob * log(prob))`` taken along ``dim``.

    NOTE(review): the summed quantity is the negative of the Shannon entropy;
    the sign is kept as is because ``RLHF.training_step`` consumes it in this
    form — confirm the intended sign of the entropy bonus.
    """
    entropies = (prob * log(prob)).sum(dim = dim)
    return masked_mean(entropies, mask = mask).mean()


def masked_kl_div(prob1, prob2, mask = None):
    """
    need to account for variable sequence lengths, therefore not using the built-in functional version
    """
    kl_divs = (prob1 * (log(prob2) - log(prob1))).sum(dim = -1)

    if not exists(mask):
        return kl_divs.mean()

    return masked_mean(kl_divs, mask).mean()


def clipped_value_loss(values, rewards, old_values, clip):
    """Clipped value loss: mean of the element-wise max of two squared errors.

    One error uses ``values`` directly, the other uses ``values`` clipped to
    within ``clip`` of ``old_values``. Both are flattened before ``rewards`` is
    subtracted.

    NOTE(review): the subtraction relies on broadcasting between the flattened
    values and ``rewards`` — confirm the shapes callers pass.
    """
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    value_loss_1 = (value_clipped.flatten() - rewards) ** 2
    value_loss_2 = (values.flatten() - rewards) ** 2
    return torch.mean(torch.max(value_loss_1, value_loss_2))

# rlhf

@beartype
class RLHF(pl.LightningModule):
    """LightningModule that trains an ``ActorCritic`` with PPO against a reward model."""

    def __init__(
        self,
        args,
        rwkv: RWKV,
        reward_model: RewardModel
    ):
        """
        Args:
            args: configuration object; attributes are read throughout the
                class (for example ``critic_pooled_values``, ``lr_init``,
                ``eps_clip``).
            rwkv: base model used to initialize the actor and the critic.
            reward_model: model that scores generated sequences.
        """
        super().__init__()

        self.args = args

        self.rwkv = rwkv

        # Initialize the actor-critic from RWKV.
        actor_critic = ActorCritic(
            rwkv = self.rwkv,
            args = self.args,
            pooled_values = args.critic_pooled_values
        ).to(self.rwkv.device)

        self.actor_critic = actor_critic

        # Put the reward model into evaluation mode.
        self.reward_model = reward_model.eval()

    def save(self, filepath = './checkpoint.pt'):
        """Write the actor-critic state dict to ``filepath``."""
        torch.save(self.actor_critic.state_dict(), filepath)

    def load(self, filepath = './checkpoint.pt'):
        """Load an actor-critic state dict from ``filepath``.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        state_dict = torch.load(filepath)
        self.actor_critic.load_state_dict(state_dict)

    def configure_optimizers(self):
        """Build the optimizer over all parameters of this module.

        With ``args.layerwise_lr > 0`` parameters are split into three groups
        by name ("time_mix", "time_decay", "time_first", everything else), each
        tagged with a ``my_lr_scale``; otherwise a single group is used.
        Returns ``DeepSpeedCPUAdam`` when DeepSpeed offloading is enabled and
        ``FusedAdam`` otherwise.

        NOTE(review): ``self.named_parameters()`` also covers ``self.rwkv`` and
        ``self.reward_model`` — confirm that is intended.
        """
        args = self.args

        if args.layerwise_lr > 0:
            lr_1x = set()
            lr_2x = set()
            lr_3x = set()
            for n, _ in self.named_parameters():
                if "time_mix" in n:
                    if args.my_pile_stage == 2:
                        lr_2x.add(n)
                    else:
                        lr_1x.add(n)
                elif "time_decay" in n:
                    if args.my_pile_stage == 2:
                        lr_3x.add(n)
                    else:
                        lr_2x.add(n)
                elif "time_first" in n:
                    lr_3x.add(n)
                else:
                    lr_1x.add(n)

            # Sort for a deterministic parameter order inside each group.
            lr_1x = sorted(lr_1x)
            lr_2x = sorted(lr_2x)
            lr_3x = sorted(lr_3x)

            param_dict = {n: p for n, p in self.named_parameters()}

            if args.my_pile_stage == 2:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 2e-3 / args.lr_init
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 3e-3 / args.lr_init
                ]
            else:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
                ]
        else:
            optim_groups = [
                {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
            ]

        if self.deepspeed_offload:
            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)

    @property
    def deepspeed_offload(self) -> bool:
        """True when the trainer uses DeepSpeed with optimizer or parameter offload."""
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            cfg = strategy.config["zero_optimization"]
            # bool() so the result matches the annotation rather than leaking
            # the raw config entry (a dict or None).
            return bool(cfg.get("offload_optimizer") or cfg.get("offload_param"))
        return False

    @torch.no_grad()
    def generate(
        self,
        max_seq_len,
        *args,
        prompt,
        num_samples = 4,  # sample 4 per prompt and select the one with highest reward
        **kwargs
    ):
        """Sample ``num_samples`` responses for one prompt and return the best one.

        Not part of training; used for inference only.

        Args:
            max_seq_len: forwarded to ``ActorCritic.generate``.
            *args: extra positional arguments for ``ActorCritic.generate``.
            prompt: 1-D tensor of prompt token ids.
            num_samples: number of candidate sequences to sample.
            **kwargs: extra keyword arguments for ``ActorCritic.generate``.

        Returns:
            The sampled sequence (prompt + response) with the highest reward.
        """
        assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
        prompt = repeat(prompt, 'n -> b n', b = num_samples)

        self.actor_critic.eval()

        (
            _actions,
            sequences,
            mask,
            prompt_mask,
            _action_logits,
            _
        ) = self.actor_critic.generate(
            prompt,
            *args,
            max_seq_len = max_seq_len,
            return_values = False,
            **kwargs
        )

        rewards = self.reward_model(
            sequences,
            prompt_mask = prompt_mask,
            mask = mask,
            sample = True
        )

        best_sequence_index = rewards.topk(1, dim = -1).indices

        # Indexing with the (1,)-shaped index keeps a leading axis; drop it.
        best_sequence = sequences[best_sequence_index]
        best_sequence = rearrange(best_sequence, '1 ... -> ...')

        return best_sequence

    def training_step(self, batch, batch_idx):
        """Compute the PPO actor and critic losses for one batch of experience.

        ``batch`` is unpacked as ``(sequences, prompt_masks, masks,
        old_action_probs, old_log_probs, rewards, old_values)``.

        Returns:
            dict with ``'loss'`` (tensor, ``actor_loss + critic_loss``) plus
            ``'actor_loss'`` and ``'critic_loss'`` as Python floats.
        """
        sequences, \
        prompt_masks, \
        masks, \
        old_action_probs, \
        old_log_probs, \
        rewards, \
        old_values = batch

        # PPO training
        action_masks = ~prompt_masks & masks

        action_logits, values = self.actor_critic(
            sequences,
            mask = action_masks
        )

        # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
        action_logits = shift(action_logits, shift = 1, dim = -2)
        action_len = old_log_probs.shape[-1]

        action_probs = action_logits.softmax(dim = -1)
        action_log_probs = log_prob(action_probs, sequences)
        action_log_probs = action_log_probs[:, -action_len:]

        # calculate entropies, taking into account which part of the sequence is actually an action
        entropies = masked_entropy(action_probs, mask = action_masks)

        # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
        kl_div_loss = 0.

        if self.args.kl_div_loss_weight > 0:
            kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight

        # handle non-pooled values
        normalize_kwargs = dict()

        if old_values.ndim == 2:
            # NOTE(review): for 2-D values ``dim = -2`` is the first axis —
            # confirm this shift is meant to act there and not on the last axis.
            old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))

            old_values = old_values[:, -action_len:]
            values = values[:, -action_len:]
            rewards = rearrange(rewards, 'b -> b 1')
            normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])

        if values.ndim < rewards.ndim:
            values = rearrange(values, '... -> ... 1')

        # calculate clipped surrogate objective, classic PPO loss
        ratios = (action_log_probs - old_log_probs).exp()
        advantages = masked_normalize(rewards - old_values, **normalize_kwargs)

        if advantages.ndim == 1:
            advantages = rearrange(advantages, 'b -> b 1')

        surr1 = ratios * advantages
        surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
        policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies

        # Actor loss (also called policy loss): the loss of the model that is
        # ultimately used.
        actor_loss = policy_loss.mean() + kl_div_loss

        # Critic loss (also called value loss).
        # update value network separate from policy network
        critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
        critic_loss = critic_loss.mean()

        # Lightning's automatic optimization backpropagates the 'loss' entry;
        # without it nothing would be optimized. The sum is used because one
        # optimizer covers all parameters. The float entries are kept for
        # logging and for existing consumers.
        return {
            'loss': actor_loss + critic_loss,
            'actor_loss': actor_loss.item(),
            'critic_loss': critic_loss.item()
        }

    def make_experience(self, prompts, eos_token=None, temperature=1):
        """Generate one training sample by interacting with the environment.

        Picks a random prompt, strips its padding, samples a response with the
        actor-critic and scores the full sequence with the reward model.

        Args:
            prompts: indexable collection of 1-D prompt tensors padded with
                ``self.args.pad_value``.
            eos_token: optional end-of-sequence token id.
            temperature: sampling temperature forwarded to generation.

        Returns:
            ``(sequence, prompt_mask, mask, action_prob, action_log_prob,
            reward, value)``.
        """
        device = self.device

        # select a bunch of random states (prompts)
        # and get the action (sampled sequence from rwkv as well as the action probs)
        # also calculate the reward using reward model and store

        # Pick one prompt at random.
        rand_prompt_index = randrange(0, len(prompts))
        state = prompts[rand_prompt_index]

        # remove padding from state
        state_mask = state != self.args.pad_value
        state = state[state_mask]

        # get predicted sequence
        # Interact with the environment; of the returned values,
        # ``actions`` is the response and
        # ``sequence`` is prompt + response.
        (
            actions,
            sequence,
            mask,
            prompt_mask,
            action_logits,
            value
        ) = self.actor_critic.generate(
            rearrange(state, 'n -> 1 n'),
            max_seq_len = self.args.ctx_len,
            eos_token = eos_token,
            temperature = temperature,
            return_values = True
        )

        # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
        action_logits = shift(action_logits, shift = 1, dim = -2)

        action_prob = action_logits.softmax(dim = -1)

        action_len = actions.shape[-1]
        action_log_prob = log_prob(action_prob, sequence)
        action_log_prob = action_log_prob[:, -action_len:]

        actions = rearrange(actions, '1 ... -> ...')

        # get reward as given by supervised trained reward model
        sequence = torch.cat((state, actions), dim = 0)

        prompt_length = len(state)
        prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length

        sequence = rearrange(sequence, 'n -> 1 n')
        prompt_mask = rearrange(prompt_mask, 'n -> 1 n')

        if not exists(mask):
            mask = torch.ones(sequence.shape, dtype = torch.bool, device = device)
        elif mask.ndim == 1:
            # ActorCritic.generate builds the mask from the batched sequence,
            # so it normally already has a batch axis; only add one when it is
            # missing (an unconditional 'n -> 1 n' fails on a 2-D mask).
            mask = rearrange(mask, 'n -> 1 n')

        reward = self.reward_model(
            sequence,
            prompt_mask = prompt_mask,
            mask = mask,
            sample = True
        )

        return (
            sequence,
            prompt_mask,
            mask,
            action_prob,
            action_log_prob,
            reward,
            value
        )