import copy
from pathlib import Path

from beartype import beartype

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from src.rlhf.utils import masked_mean, gumbel_sample
from src.model import RWKV

# helper functions

def exists(val):
    return val is not None

# Reward Model - RWKV with a scalar head

@beartype
class RewardModel(nn.Module):
    def __init__(
        self,
        rwkv: RWKV,
        dropout = 0.1,
        num_binned_output = 0
    ):
        super().__init__()

        # initialize the reward model from the pretrained model
        self.rwkv = copy.deepcopy(rwkv)
        self.rwkv.set_dropout(dropout)   # todo(luxin)

        # dimension of the output token embeddings
        dim = rwkv.dim   # todo(luxin)

        # number of score bins; e.g. 5 gives the bins [0, 1, 2, 3, 4]
        self.binned_output = num_binned_output > 1

        # todo(luxin): prompt_embed and response_embed are both initialized to all zeros -- shouldn't they be distinguished?
        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
        self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
        # self.response_embed = nn.Parameter(torch.ones(1, 1, dim))

        if self.binned_output:
            # more than one score bin: a multi-class classification problem
            self.to_pred = nn.Linear(dim, num_binned_output)
        else:
            # otherwise, regress a single scalar reward (trained with MSE)
            self.to_pred = nn.Sequential(
                nn.Linear(dim, 1, bias = False),
                Rearrange('... 1 -> ...')   # squeeze the trailing singleton dimension
            )

    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))

    def finetune_parameters(self):
        return [
            *self.to_pred.parameters(),
            *self.rwkv.parameters()
        ]

    def forward(
        self,
        x,
        mask = None,
        prompt_mask = None,
        prompt_lengths = None,
        labels = None,
        sample = False,
        sample_temperature = 1.
    ):
        # at most one of prompt_mask / prompt_lengths may be given
        assert not (exists(prompt_mask) and exists(prompt_lengths))

        # derive prompt mask from prompt lengths
        if exists(prompt_lengths):
            batch, seq_len = x.shape
            arange = torch.arange(seq_len, device = x.device)
            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')

        # reward model should have an understanding of which section is prompt, and which section is response
        # per token, pick prompt_embed where prompt_mask is True, response_embed where it is False
        extra_embed = None

        if exists(prompt_mask):
            extra_embed = torch.where(
                rearrange(prompt_mask, 'b n -> b n 1'),
                self.prompt_embed,
                self.response_embed
            )

        # todo(luxin): get embeddings from rwkv
        embeds = self.rwkv(
            x,
            extra_embed = extra_embed,
            return_only_embedding = True
        )

        # mean-pool all token embeddings, then feed the pooled vector to the scoring head
        pooled = masked_mean(embeds, mask, dim = 1)
        pred = self.to_pred(pooled)

        if sample and self.binned_output:
            assert not exists(labels)
            pred = gumbel_sample(pred, temperature = sample_temperature, dim = -1)

        if not exists(labels):
            return pred

        # todo(luxin): unlike the paper, the loss here is per-sample rather than pairwise over two samples
        if not self.binned_output:
            return F.mse_loss(pred, labels)

        return F.cross_entropy(pred, labels)
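
# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). The function below
# is hypothetical; it takes an already-constructed RWKV since the exact
# constructor signature lives in src.model. Tensor shapes follow the
# forward() contract above.

def _example_usage(rwkv: RWKV, num_tokens: int = 20000):
    reward_model = RewardModel(rwkv, num_binned_output = 5)

    seq = torch.randint(0, num_tokens, (2, 1024))   # batch of prompt + response token ids
    prompt_lengths = torch.tensor([512, 640])       # tokens before these indices are prompt

    # inference: sample a discrete score bin per sequence via gumbel sampling
    scores = reward_model(seq, prompt_lengths = prompt_lengths, sample = True)

    # training: cross-entropy against gold score bins in [0, 5)
    labels = torch.randint(0, 5, (2,))
    loss = reward_model(seq, prompt_lengths = prompt_lengths, labels = labels)

    return scores, loss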
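
# ----------------------------------------------------------------------------
# For reference: the todo in forward() notes that this module uses a
# per-sample loss rather than the pairwise ranking loss from the reward
# modelling papers. A minimal sketch of that pairwise loss, assuming scalar
# rewards (binned_output = False) and hypothetical chosen/rejected token
# tensors of identical shape -- not part of this module:

def pairwise_ranking_loss(
    reward_model: RewardModel,
    chosen,            # (batch, seq_len) token ids of the preferred response
    rejected,          # (batch, seq_len) token ids of the dispreferred response
    prompt_lengths     # (batch,) shared prompt lengths
):
    chosen_rewards = reward_model(chosen, prompt_lengths = prompt_lengths)
    rejected_rewards = reward_model(rejected, prompt_lengths = prompt_lengths)

    # -log sigmoid(r_chosen - r_rejected): pushes preferred rewards above rejected ones
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()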