import copy
from pathlib import Path

from beartype import beartype

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from src.rlhf.utils import masked_mean, gumbel_sample
from src.model import RWKV

# helper functions

def exists(val):
    return val is not None

# Reward Model - RWKV with a scalar head

@beartype
class RewardModel(nn.Module):
    def __init__(
        self,
        rwkv: RWKV,
        dropout = 0.1,
        num_binned_output = 0
    ):
        super().__init__()

        # initialize the reward model from a copy of the pretrained model
        self.rwkv = copy.deepcopy(rwkv)
        self.rwkv.set_dropout(dropout)  # todo(luxin)

        # dimension of the token embeddings produced by the backbone
        dim = rwkv.dim  # todo(luxin)

        # number of score bins; e.g. 5 means scores fall into the 5 classes [0, 1, 2, 3, 4]
        self.binned_output = num_binned_output > 1

        # todo(luxin): prompt_embed and response_embed are both initialized to all zeros? shouldn't they be distinguished?
        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
        self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
        # self.response_embed = nn.Parameter(torch.ones(1, 1, dim))

        if self.binned_output:
            # more than one score bin: predict a distribution over the bins (multi-class classification)
            self.to_pred = nn.Linear(dim, num_binned_output)
        else:
            # otherwise: predict a single scalar reward, trained as a regression (see the MSE loss in forward)
            self.to_pred = nn.Sequential(
                nn.Linear(dim, 1, bias = False),
                Rearrange('... 1 -> ...')   # drop the trailing singleton dimension
            )

    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))

    def finetune_parameters(self):
        return [
            *self.to_pred.parameters(),
            *self.rwkv.parameters()
        ]

    def forward(
        self,
        x,
        mask = None,
        prompt_mask = None,
        prompt_lengths = None,
        labels = None,
        sample = False,
        sample_temperature = 1.
    ):

        # at most one of prompt_mask and prompt_lengths may be given
        assert not (exists(prompt_mask) and exists(prompt_lengths))

        # derive prompt mask from prompt lengths
        if exists(prompt_lengths):
            batch, seq_len = x.shape
            arange = torch.arange(seq_len, device = x.device)
            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')

        # reward model should have an understanding of which section is prompt, and which section is response
        # pick a per-token extra embedding according to prompt_mask:
        # True selects prompt_embed, False selects response_embed
        extra_embed = None
        if exists(prompt_mask):
            extra_embed = torch.where(
                rearrange(prompt_mask, 'b n -> b n 1'),
                self.prompt_embed,
                self.response_embed
            )

        # todo(luxin) get embeddings from rwkv
        embeds = self.rwkv(
            x,
            extra_embed = extra_embed,
            return_only_embedding = True
        )

        # mean-pool the token embeddings (respecting the mask) and feed the pooled vector to the scoring head
        pooled = masked_mean(embeds, mask, dim = 1)
        pred = self.to_pred(pooled)

        if sample and self.binned_output:
            assert not exists(labels)
            pred = gumbel_sample(pred, temperature = sample_temperature, dim = -1)

        if not exists(labels):
            return pred

        # todo(luxin): the author uses a per-sample loss here rather than the pairwise (two-sample) ranking loss from the paper
        if not self.binned_output:
            return F.mse_loss(pred, labels)

        return F.cross_entropy(pred, labels)
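
# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file).
# It assumes an RWKV backbone has already been constructed elsewhere
# (its constructor arguments are project-specific and omitted here) and
# that token ids are (batch, seq_len) LongTensors. Argument names follow
# the forward() signature above; shapes and values are placeholders.
#
#   rwkv = RWKV(...)                                      # built from the project's config
#   reward_model = RewardModel(rwkv, num_binned_output = 5)
#
#   seq            = torch.randint(0, 20000, (2, 512))    # prompt + response token ids
#   prompt_lengths = torch.tensor([100, 128])             # prompt length of each sample
#   labels         = torch.tensor([3, 1])                 # human score bins in [0, 4]
#
#   # training: returns cross-entropy loss (or MSE loss if num_binned_output <= 1)
#   loss = reward_model(seq, prompt_lengths = prompt_lengths, labels = labels)
#   loss.backward()
#
#   # inference: sample a predicted score bin per sequence
#   score = reward_model(seq, prompt_lengths = prompt_lengths, sample = True)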