import copy
from pathlib import Path

from tqdm import tqdm

from beartype import beartype
from beartype.typing import Tuple, Optional

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from src.rlhf.utils import masked_mean, gumbel_sample
from src.rlhf.rwkv.model import RWKV

# helper functions

def exists(val):
    return val is not None

# Reward Model - RWKV with a scalar head

@beartype
class RewardModel(nn.Module):
    def __init__(
        self,
        rwkv: RWKV
    ):
        super().__init__()

        # initialize the reward model from the pretrained RWKV model
        self.rwkv = rwkv

        # dimension of the output token embeddings
        dim = rwkv.args.n_embd

        # learned embeddings that distinguish prompt tokens from response tokens,
        # trained as model parameters and initialized to zeros
        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
        self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))

        # scalar reward head
        self.pred_reward = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... 1 -> ...')  # squeeze the trailing dimension
        )

    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))

    def finetune_parameters(self):
        return [
            *self.pred_reward.parameters(),
            *self.rwkv.parameters()
        ]

    def forward(
        self,
        x,
        mask = None,
        prompt_mask = None,
        prompt_lengths = None
    ):
        # prompt_mask and prompt_lengths are mutually exclusive
        assert not (exists(prompt_mask) and exists(prompt_lengths))

        # derive prompt mask from prompt lengths
        if exists(prompt_lengths):
            batch, seq_len = x.shape
            arange = torch.arange(seq_len, device = x.device)
            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')

        # the reward model should have an understanding of which section is the prompt
        # and which section is the response

        # select per token from prompt_embed or response_embed according to prompt_mask:
        # True picks the prompt embedding, False picks the response embedding
        extra_embed = None

        if exists(prompt_mask):
            extra_embed = torch.where(
                rearrange(prompt_mask, 'b n -> b n 1'),
                self.prompt_embed,
                self.response_embed
            )

        # get the embeddings of every token from the RWKV trunk
        embeds = self.rwkv(
            x,
            state = None,
            extra_embed = extra_embed
        )

        # mean-pool the token embeddings and feed the result to the scoring head
        pooled = masked_mean(embeds, mask, dim = 1)
        reward = self.pred_reward(pooled)

        return reward
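
# ----------------------------------------------------------------------------
# Usage sketch (not part of the original module, kept as comments): given a
# pretrained RWKV trunk `rwkv` loaded elsewhere in this repo, the reward model
# assigns a single scalar score to each sequence in a batch. `prompt_lengths`
# marks how many leading tokens of each sequence belong to the prompt; the
# remaining tokens are treated as the response. The vocabulary size below is a
# placeholder assumption for illustration only.
#
#   reward_model = RewardModel(rwkv)
#
#   tokens = torch.randint(0, 50277, (2, 16))     # (batch, seq_len), placeholder vocab size
#   prompt_lengths = torch.tensor([5, 8])         # leading prompt tokens per sequence
#
#   rewards = reward_model(tokens, prompt_lengths = prompt_lengths)
#   # rewards has shape (batch,): one scalar reward per sequence
# ----------------------------------------------------------------------------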