reward.py
import copy
from pathlib import Path

from tqdm import tqdm
from beartype import beartype
from beartype.typing import Tuple, Optional

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from src.rlhf.utils import masked_mean, gumbel_sample
from src.rlhf.rwkv.model import RWKV

# helper functions

def exists(val):
    return val is not None

# Reward Model - RWKV with a scalar head

@beartype
class RewardModel(nn.Module):
    def __init__(
        self,
        rwkv: RWKV
    ):
        super().__init__()

        # initialize the reward model from a pretrained RWKV model
        self.rwkv = rwkv

        # dimension of the token embeddings produced by the backbone
        dim = rwkv.args.n_embd

        # learned embeddings used to distinguish prompt tokens from response tokens
        # in the input; trained together with the model, initialized to all zeros
        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
        self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))

        # scalar head that computes the reward score
        self.pred_reward = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... 1 -> ...')   # drop the trailing singleton dimension
        )
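        # shape flow through the head: (batch, dim) -> (batch, 1) -> (batch,)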

    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))

    def finetune_parameters(self):
        return [
            *self.pred_reward.parameters(),
            *self.rwkv.parameters()
        ]
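
    # example (assumption: the optimizer type and learning rate below are
    # illustrative only, not prescribed by this file):
    #   optimizer = torch.optim.Adam(reward_model.finetune_parameters(), lr = 1e-5)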

    def forward(
        self,
        x,
        mask = None,
        prompt_mask = None,
        prompt_lengths = None
    ):

        # only one of prompt_mask and prompt_lengths may be provided
        assert not (exists(prompt_mask) and exists(prompt_lengths))

        # derive prompt mask from prompt lengths
        if exists(prompt_lengths):
            batch, seq_len = x.shape
            arange = torch.arange(seq_len, device = x.device)
            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')
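
            # e.g. prompt_lengths = [2, 3] with seq_len = 4 yields
            # prompt_mask = [[True, True, False, False],
            #                [True, True, True,  False]]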

        # the reward model should understand which section is prompt and which is response,
        # so pick a per-token embedding based on prompt_mask:
        # prompt_embed where True, response_embed where False
        extra_embed = None
        if exists(prompt_mask):
            extra_embed = torch.where(
                rearrange(prompt_mask, 'b n -> b n 1'),
                self.prompt_embed,
                self.response_embed
            )
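            # after broadcasting, extra_embed has shape (batch, seq_len, dim):
            # the (1, 1, dim) learned embeddings are selected per position by the
            # (batch, seq_len, 1) boolean mask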

        # run the RWKV backbone to obtain the token embeddings
        last_token_embeds = self.rwkv(
            x,
            state=None,
            extra_embed=extra_embed
        )

        # average the token embeddings (respecting the mask) and feed the pooled
        # vector into the scoring head
        pooled = masked_mean(last_token_embeds, mask, dim = 1)
        reward = self.pred_reward(pooled)

        return reward
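

# Minimal usage sketch (assumptions: `args` holding `n_embd` / `vocab_size` and a
# pretrained RWKV checkpoint are prepared elsewhere, and `masked_mean` averages
# over all positions when mask is None):
#
#   rwkv = RWKV(args)
#   reward_model = RewardModel(rwkv)
#
#   tokens = torch.randint(0, args.vocab_size, (2, 128))     # (batch, seq_len)
#   prompt_lengths = torch.tensor([64, 96])                   # prompt length per sample
#
#   rewards = reward_model(tokens, prompt_lengths = prompt_lengths)   # shape: (batch,)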