import copy
from pathlib import Path

from tqdm import tqdm
from beartype import beartype
from beartype.typing import Tuple, Optional

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from src.rlhf.utils import masked_mean, gumbel_sample
from src.model import RWKV

# helper functions

def exists(val):
    return val is not None

# loss function
# pairwise ranking loss: minimizing -log(sigmoid(r_prefer - r_alter))
# pushes the preferred response's reward above the alternative's
def loss_function(prefer_reward, alter_reward):
    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))

# Reward Model - RWKV with a scalar head

@beartype
class RewardModel(pl.LightningModule):
    def __init__(self, args):
        super().__init__()

        # load the RWKV backbone
        rwkv = RWKV(args)

        if len(args.load_model) == 0:
            rank_zero_info("RewardModel must be initialized from a pretrained model, please set --load_model")
            exit(1)

        rank_zero_info(f"########## Loading {args.load_model}... ##########")
        try:
            load_dict = torch.load(args.load_model, map_location="cpu")
        except Exception:
            rank_zero_info(f"Bad checkpoint {args.load_model}")
            exit(1)

        if args.load_partial == 1:
            load_keys = load_dict.keys()
            for k in rwkv.state_dict():
                if k not in load_keys:
                    load_dict[k] = rwkv.state_dict()[k]
        rwkv.load_state_dict(load_dict)

        # initialize the reward model from the pretrained RWKV model
        self.rwkv = rwkv
        self.args = args

        # dimension of the output token embeddings
        dim = self.args.n_embd

        # learned embeddings that mark each input token as prompt or response;
        # trained as model parameters, initialized to zero
        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
        self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
        self.padding_embed = nn.Parameter(torch.zeros(1, 1, dim), requires_grad=False)

        # scalar reward head
        self.pred_reward = nn.Sequential(
            nn.Linear(dim, 1, bias=False),
            Rearrange('... 1 -> ...')  # drop the trailing singleton dimension
        )

    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))

    def configure_optimizers(self):
        # hyper-parameters from the paper:
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5, betas=(0.9, 0.95))
        # optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr_init, betas=self.args.betas)
        return optimizer

    def single_forward(
        self,
        x,
        mask = None,
        prompt_mask = None,
        prompt_lengths = None
    ):
        # prompt_mask and prompt_lengths are mutually exclusive
        assert not (exists(prompt_mask) and exists(prompt_lengths))

        # derive prompt mask from prompt lengths
        if exists(prompt_lengths):
            batch, seq_len = x.shape
            arange = torch.arange(seq_len, device = x.device)
            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')

        # reward model should have an understanding of which section is prompt, and which section is response
        # pick prompt_embed or response_embed per token according to prompt_mask:
        # True selects prompt_embed, False selects response_embed
        extra_embed = None
        if exists(prompt_mask):
            extra_embed = torch.where(
                rearrange(prompt_mask, 'b n -> b n 1'),
                self.prompt_embed,
                self.response_embed
            )

        # take the embedding of the last token
        last_token_embeds = self.rwkv(
            x,
            state=None,
            extra_embed=extra_embed,
            rm_train=True
        )[:, -1, :]

        # feed the last token's embedding to the scalar head to produce the reward
        reward = self.pred_reward(last_token_embeds)

        return reward

    def forward(self, x_p, x_a, m_p, m_a):
        prefer_reward = self.single_forward(x_p, prompt_mask=m_p)
        alter_reward = self.single_forward(x_a, prompt_mask=m_a)

        return prefer_reward, alter_reward

    def training_step(self, batch, batch_idx):
        x_p, x_a, m_p, m_a = batch

        prefer_reward, alter_reward = self(x_p, x_a, m_p, m_a)

        loss = loss_function(prefer_reward, alter_reward)

        return loss
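
# ------------------------------------------------------------------
# Minimal sanity-check sketch (an illustrative addition, not part of the
# training pipeline). It only exercises the pairwise ranking loss and the
# prompt-mask derivation used in single_forward, with dummy tensors and
# without loading an RWKV checkpoint, so it runs with just torch + einops.
if __name__ == "__main__":
    # the loss shrinks as the preferred reward rises above the alternative
    prefer = torch.tensor([1.0, 2.0, 0.5])
    alter = torch.tensor([0.0, 0.1, 0.4])
    print("pairwise ranking loss:", loss_function(prefer, alter).item())

    # deriving a boolean prompt mask from per-sample prompt lengths,
    # mirroring the prompt_lengths branch in single_forward
    batch, seq_len = 2, 6
    prompt_lengths = torch.tensor([2, 4])
    arange = torch.arange(seq_len)
    prompt_mask = repeat(arange, 'n -> b n', b=batch) < rearrange(prompt_lengths, 'b -> b 1')
    print("prompt_mask:\n", prompt_mask)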