import copy
from pathlib import Path

from tqdm import tqdm
from beartype import beartype
from beartype.typing import Tuple, Optional

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from src.rlhf.utils import masked_mean, gumbel_sample
from src.model import RWKV

# helper functions

def exists(val):
    return val is not None

# pairwise preference loss: drive the preferred reward above the alternative one,
# i.e. -E[log(sigmoid(r_prefer - r_alter))]; logsigmoid is used for numerical stability
def loss_function(prefer_reward, alter_reward):
    return -torch.mean(F.logsigmoid(prefer_reward - alter_reward))
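
# A quick sanity check of the pairwise loss (a minimal sketch, not part of the
# training code): with a preferred reward of 2.0 and an alternative reward of 0.0
# the loss is small, and it grows once the ordering is flipped.
#
#   prefer, alter = torch.tensor([2.0]), torch.tensor([0.0])
#   loss_function(prefer, alter)   # ~0.13, i.e. -log(sigmoid(2.0))
#   loss_function(alter, prefer)   # ~2.13, the mis-ranked pair is penalized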

# Reward Model - RWKV with a scalar head

@beartype
class RewardModel(pl.LightningModule):
    def __init__(self, args):
        super().__init__()

        # load the RWKV model
        rwkv = RWKV(args)

        if len(args.load_model) == 0:
            rank_zero_info(f"SFT must load model, please input ")
            exit(1)

        rank_zero_info(f"########## Loading {args.load_model}... ##########")
        try:
            load_dict = torch.load(args.load_model, map_location="cpu")
        except Exception:
            rank_zero_info(f"Bad checkpoint {args.load_model}")
            exit(1)

        if args.load_partial == 1:
            load_keys = load_dict.keys()
            for k in rwkv.state_dict():
                if k not in load_keys:
                    load_dict[k] = rwkv.state_dict()[k]
        rwkv.load_state_dict(load_dict)

        # initialize the reward model from the pretrained weights
        self.rwkv = rwkv
        self.args = args

        # dimension of the output token embeddings
        dim = self.args.n_embd

        # learned embeddings that mark whether a token belongs to the prompt or the
        # response; trained with the model and initialized to zeros
        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
        self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))

        # reward prediction head
        self.pred_reward = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... 1 -> ...')   # squeeze the trailing singleton dimension
        )
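
        # shape note (for reference): the pooled embedding entering this head is
        # [batch, dim]; nn.Linear(dim, 1) maps it to [batch, 1] and the Rearrange
        # drops the trailing axis, so pred_reward returns one scalar reward per
        # sequence, i.e. a tensor of shape [batch]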

    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))

    def finetune_parameters(self):
        return [
            *self.pred_reward.parameters(),
            *self.rwkv.parameters()
        ]
    
    def configure_optimizers(self):
        # hyperparameters from the paper: lr=1e-5, betas=(0.9, 0.95)
        # self.parameters() already covers the RWKV trunk, the prompt/response
        # embeddings and the prediction head, so a single parameter group suffices
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.args.lr_init, betas=self.args.betas
        )

        return optimizer

    def single_forward(
        self,
        x,
        mask = None,
        prompt_mask = None,
        prompt_lengths = None
    ):

        # prompt_mask and prompt_lengths are mutually exclusive
        assert not (exists(prompt_mask) and exists(prompt_lengths))

        # derive prompt mask from prompt lengths
        if exists(prompt_lengths):
            batch, seq_len = x.shape
            arange = torch.arange(seq_len, device = x.device)
            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')

        # reward model should have an understanding of which section is prompt, and which section is response
        # based on prompt_mask, pick the extra embedding for each token:
        # True positions take prompt_embed, False positions take response_embed
        extra_embed = None
        if exists(prompt_mask):
            extra_embed = torch.where(
                rearrange(prompt_mask, 'b n -> b n 1'),
                self.prompt_embed,
                self.response_embed
            )
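
        # illustrative example (hypothetical values): with batch 1, seq_len 4 and
        # prompt_lengths = [2], prompt_mask is [[True, True, False, False]], so the
        # first two positions pick up prompt_embed and the last two response_embed;
        # broadcasting against the [1, 1, dim] embeddings gives extra_embed the
        # shape [batch, seq_len, dim]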

        # run the RWKV backbone to obtain the per-token embeddings (rm_train=True)
        last_token_embeds = self.rwkv(
            x,
            state=None,
            extra_embed=extra_embed,
            rm_train=True
        )

        # average the token embeddings and feed the pooled vector to the scoring head
        pooled = masked_mean(last_token_embeds, mask, dim = 1)
        reward = self.pred_reward(pooled)

        return reward
    
    def forward(self, prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask):
        # pass the masks as prompt_mask so prompt and response tokens receive their
        # distinguishing embeddings (a positional argument would bind to `mask`)
        prefer_reward = self.single_forward(prefer_x, prompt_mask=prefer_x_prompt_mask)
        alter_reward = self.single_forward(alter_x, prompt_mask=alter_x_prompt_mask)

        return prefer_reward, alter_reward
    
    def training_step(self, batch, batch_idx):
        prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask = batch
        prefer_reward, alter_reward = self(
            prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask)
        
        loss = loss_function(prefer_reward, alter_reward)

        return loss
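

# A minimal usage sketch (an assumption about the surrounding project, not part of
# this module): `args` is expected to carry the fields referenced above
# (load_model, load_partial, n_embd, lr_init, betas), and `pairwise_loader` is a
# hypothetical DataLoader yielding
# (prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask) batches.
#
#   model = RewardModel(args)
#   trainer = pl.Trainer(max_epochs=1, accelerator="auto")
#   trainer.fit(model, train_dataloaders=pairwise_loader)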