Commit 5f0fb2fd authored by u010280923

add reward model

Parent 6f1de051
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/3/1 11:54
# @Author : clong
# @File : train_sft.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
if __name__ == "__main__":
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
rank_zero_info("########## work in progress ##########")
########################################################################################################
#
# example: train a simple L12-D768 RWKV on dummy data
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "" --data_type "dummy" --vocab_size 0 \
# --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
# --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
# --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: train a simple L6-D512 RWKV from scratch on enwik8
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
# example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
# --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
parser = ArgumentParser()
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--random_seed", default="-1", type=int)
parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int)
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
parser.add_argument("--my_img_version", default=0, type=str)
parser.add_argument("--my_img_size", default=0, type=int)
parser.add_argument("--my_img_bit", default=0, type=int)
parser.add_argument("--my_img_clip", default='x', type=str)
parser.add_argument("--my_img_clip_scale", default=1, type=float)
parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str)
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int)
parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
########################################################################################################
import os, warnings, math, datetime, sys, time
import numpy as np
import torch
from torch.utils.data import DataLoader
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
seed_everything(args.random_seed)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
# os.environ["WDS_SHOW_SEED"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False
args.replace_sampler_ddp = False
args.logger = False
args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0
args.check_val_every_n_epoch = int(1e20)
args.log_every_n_steps = int(1e20)
args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
os.environ["RWKV_MY_TESTING"] = args.my_testing
if args.dim_att <= 0:
args.dim_att = args.n_embd
if args.dim_ffn <= 0:
args.dim_ffn = args.n_embd * 4
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
if not os.path.exists(args.proj_dir):
os.makedirs(args.proj_dir)
samples_per_epoch = args.epoch_steps * args.real_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len
rank_zero_info(
f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
)
rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT_ON"] = "0"
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
if args.precision == "fp32":
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if "32" in args.precision:
args.precision = 32
elif args.precision == "fp16":
args.precision = 16
else:
args.precision = "bf16"
########################################################################################################
from src.trainer import train_callback, generate_init_weight
# temporary smoke test left in this commit: hard-code the vocab size, build the
# RWKV model, push one random token batch through it, then drop into the ipdb debugger
args.vocab_size = 20000
from src.model import RWKV
model = RWKV(args)
seq = torch.randint(0, 20000, (1, 100))
model(seq)
import ipdb
ipdb.set_trace()
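########################################################################################################
# file boundary (inferred, not marked in this excerpt): everything below is the new RLHF code
# added by this commit, shown concatenated. The helpers and LoRA class next likely form
# src/rlhf/lora.py (it is imported as `from src.rlhf.lora import LoRA` further down).
########################################################################################################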
# from palm_rlhf_pytorch.palm import PaLM
# from palm_rlhf_pytorch.reward import RewardModel
# from palm_rlhf_pytorch.ppo import RLHFTrainer, ActorCritic
import torch
from torch import nn
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# LoRA - https://arxiv.org/abs/2106.09685
class LoRA(nn.Module):
def __init__(
self,
dim,
dim_out,
r = 8,
alpha = None
):
super().__init__()
alpha = default(alpha, r)
self.scale = alpha / r
self.A = nn.Parameter(torch.randn(dim, r))
self.B = nn.Parameter(torch.zeros(r, dim_out))
@property
def weight(self):
return (self.A @ self.B) * self.scale
def forward(self, x):
return x @ self.weight
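# minimal usage sketch (illustrative, not part of the commit): at r = 8, a LoRA
# adapter on a 512 -> 512 projection trains 2 * 512 * 8 = 8192 extra parameters
# instead of 512 * 512 = 262144, and the zero-initialized B makes the initial
# update an exact no-op
def _lora_demo():
    lora = LoRA(dim = 512, dim_out = 512, r = 8)
    x = torch.randn(2, 16, 512)
    delta = lora(x)  # low-rank update, added on top of the frozen base projection
    assert delta.shape == (2, 16, 512)
    assert torch.equal(delta, torch.zeros_like(delta))  # B starts at zero
########################################################################################################
# file boundary (inferred): the optimizer helpers below likely form src/rlhf/optimizer.py
########################################################################################################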
from torch.optim import AdamW, Adam
from lion_pytorch import Lion
def separate_weight_decayable_params(params):
wd_params, no_wd_params = [], []
for param in params:
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
return wd_params, no_wd_params
def get_optimizer(
params,
lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.99),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True,
use_lion = True,
**kwargs
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if group_wd_params and wd > 0:
wd_params, no_wd_params = separate_weight_decayable_params(params)
params = [
{'params': wd_params},
{'params': no_wd_params, 'weight_decay': 0},
]
if use_lion:
return Lion(params, lr = lr, betas = betas, weight_decay = wd)
if wd == 0:
return Adam(params, lr = lr, betas = betas, eps = eps)
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
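# minimal usage sketch (illustrative, not part of the commit): with wd > 0,
# group_wd_params routes 1-D parameters (biases, norm gains) into a group with
# weight_decay = 0, so only weight matrices are decayed; use_lion = False selects AdamW
def _optimizer_demo():
    from torch import nn
    model = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))
    optim = get_optimizer(model.parameters(), lr = 1e-4, wd = 1e-2, use_lion = False)
    assert len(optim.param_groups) == 2  # decayed matrices vs. un-decayed 1-D params
########################################################################################################
# file boundary (inferred): the PaLM transformer below imports from src.rlhf.utils and
# src.rlhf.lora, so it likely lives at src/rlhf/palm.py
########################################################################################################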
import math
import copy
from pathlib import Path
from collections import namedtuple
from itertools import zip_longest
from tqdm import tqdm
from beartype import beartype
from beartype.typing import Tuple, Optional
import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from src.rlhf.utils import top_p, top_k, masked_mean, gumbel_sample, eval_decorator
from src.rlhf.lora import LoRA
# functions and decorators
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def identity(t, *args, **kwargs):
return t
def l2norm(t):
return F.normalize(t, dim = -1)
# normalization
# they use layernorm without bias, something that pytorch does not offer
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
# residual
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
y = self.fn(x, **kwargs)
if not any([t.requires_grad for t in (x, y)]):
return x.add_(y)
return y + x
# rotary positional embedding w/ xpos
# https://arxiv.org/abs/2104.09864
# https://arxiv.org/abs/2212.10554v1
class RotaryEmbedding(nn.Module):
def __init__(self, dim, scale_base = 512, use_xpos = True):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.use_xpos = use_xpos
self.scale_base = scale_base
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.register_buffer('scale', scale)
def forward(self, seq_len, device):
t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)
if not self.use_xpos:
return freqs, torch.ones(1, device = device)
power = (t - (seq_len // 2)) / self.scale_base
scale = self.scale ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
return freqs, scale
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(pos, t, scale = 1.):
return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)
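# shape sketch (illustrative, not part of the commit): rotary embeddings rotate
# feature pairs by position-dependent angles; with xpos the queries are scaled by
# `scale` and the keys by `scale ** -1`, so the decay cancels at short range and
# attenuates long-range attention
def _rotary_demo():
    rotary = RotaryEmbedding(dim = 64)
    pos, scale = rotary(8, device = torch.device('cpu'))
    q = torch.randn(1, 2, 8, 64)  # (batch, heads, seq, dim_head)
    q_rot = apply_rotary_pos_emb(pos, q, scale)
    assert q_rot.shape == q.shape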
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202
class SwiGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame
class ParallelTransformerBlock(nn.Module):
def __init__(
self,
dim,
dim_head = 64,
causal = True,
heads = 8,
qk_rmsnorm = False,
qk_scale = 8,
ff_mult = 4,
attn_dropout = 0.,
ff_dropout = 0.,
use_xpos = True,
xpos_scale_base = 512
):
super().__init__()
self.norm = LayerNorm(dim)
attn_inner_dim = dim_head * heads
ff_inner_dim = dim * ff_mult
self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
self.qk_rmsnorm = qk_rmsnorm
if qk_rmsnorm:
self.q_scale = nn.Parameter(torch.ones(dim_head))
self.k_scale = nn.Parameter(torch.ones(dim_head))
self.heads = heads
self.scale = (dim_head ** -0.5) if not qk_rmsnorm else qk_scale
self.causal = causal
self.rotary_emb = RotaryEmbedding(dim_head, scale_base = xpos_scale_base, use_xpos = use_xpos and causal)
self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
self.attn_dropout = nn.Dropout(attn_dropout)
# parallel feedforward tail
self.ff_out = nn.Sequential(
SwiGLU(),
nn.Dropout(ff_dropout),
nn.Linear(ff_inner_dim, dim, bias=False)
)
# for caching causal mask and rotary embeddings
self.register_buffer("mask", None, persistent=False)
self.register_buffer("pos_emb", None, persistent=False)
self.register_buffer("pos_emb_scale", None, persistent=False)
def get_mask(self, n, device):
if exists(self.mask) and self.mask.shape[-1] >= n:
return self.mask[:n, :n]
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask
def get_rotary_embedding(self, n, device):
if exists(self.pos_emb) and self.pos_emb.shape[-2] >= n:
return self.pos_emb[:n], self.pos_emb_scale[:n]
pos_emb, scale = self.rotary_emb(n, device=device)
self.register_buffer("pos_emb", pos_emb, persistent=False)
self.register_buffer("pos_emb_scale", scale, persistent=False)
return pos_emb, scale
def forward(
self,
x,
mask = None,
finetune_modules = None
):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
n, device, h = x.shape[1], x.device, self.heads
# pre layernorm
x = self.norm(x)
# attention queries, keys, values, and feedforward inner
q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
# finetune loras
lora_q = lora_k = lora_v = lora_o = None
if exists(finetune_modules):
lora_q, lora_k, lora_v, lora_o = finetune_modules
q = q + lora_q(x)
k = k + lora_k(x)
v = v + lora_v(x)
# split heads
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
# they found no performance loss past a certain scale, and more efficient decoding obviously
# https://arxiv.org/abs/1911.02150
q = rearrange(q, "b n (h d) -> b h n d", h=h)
# qk rmsnorm
if self.qk_rmsnorm:
q, k = map(l2norm, (q, k))
q = q * self.q_scale
k = k * self.k_scale
# rotary embeddings with xpos decay for better length extrapolation
positions, scale = self.get_rotary_embedding(n, device)
q = apply_rotary_pos_emb(positions, q, scale)
k = apply_rotary_pos_emb(positions, k, scale ** -1)
# similarity
sim = einsum("b h i d, b j d -> b h i j", q, k) * self.scale
# key padding mask
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
# causal mask
if self.causal:
causal_mask = self.get_mask(n, device)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# attention
attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)
# aggregate values
out = einsum("b h i j, b j d -> b h i d", attn, v)
# merge heads
out = rearrange(out, "b h n d -> b n (h d)")
attn_out = self.attn_out(out)
ff_out = self.ff_out(ff)
if exists(lora_o):
attn_out = attn_out + lora_o(out)
return attn_out + ff_out
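# minimal usage sketch (illustrative, not part of the commit): one fused matmul
# produces multi-head queries, single-head keys/values (multi-query attention)
# and the doubled SwiGLU feedforward input; the attention and feedforward
# branches run in parallel and their outputs are summed
def _parallel_block_demo():
    block = ParallelTransformerBlock(dim = 512, dim_head = 64, heads = 8)
    x = torch.randn(2, 16, 512)
    out = block(x)
    assert out.shape == (2, 16, 512)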
# transformer
@beartype
class PaLM(nn.Module):
def __init__(
self,
*,
dim,
num_tokens,
depth,
causal = True,
dim_head = 64,
heads = 8,
ff_mult = 4,
attn_dropout = 0.,
ff_dropout = 0.,
qk_rmsnorm = False,
lora_r = 8,
rotary_xpos_scale_base = 512,
finetune_scopes = tuple(),
cross_entropy_ignore_index = 0
):
super().__init__()
self.dim = dim
self.dim_head = dim_head
self.heads = heads
self.causal = causal
self.num_tokens = num_tokens
self.token_emb = nn.Embedding(num_tokens, dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
block = Residual(ParallelTransformerBlock(
dim = dim,
causal = causal,
dim_head = dim_head,
heads = heads,
qk_rmsnorm = qk_rmsnorm,
ff_mult = ff_mult,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
xpos_scale_base = rotary_xpos_scale_base
))
self.layers.append(block)
self.norm = LayerNorm(dim)
self.to_logits = nn.Linear(dim, num_tokens, bias=False)
self.to_logits.weight = self.token_emb.weight
nn.init.normal_(self.token_emb.weight, std=0.02)
# fine tuning related
self.lora_r = lora_r
self.finetune_modules = nn.ModuleDict({})
for scope in finetune_scopes:
self.add_finetune_params(scope)
# loss related
self.cross_entropy_ignore_index = cross_entropy_ignore_index
@property
def device(self):
return next(self.parameters()).device
def load(self, path):
path = Path(path)
assert path.exists()
self.load_state_dict(torch.load(str(path)))
def set_dropout(self, dropout):
for module in self.layers.modules():
if isinstance(module, nn.Dropout):
module.p = dropout
return self
def add_finetune_params(self, scope, lora_r = None):
assert scope not in self.finetune_modules, f'finetune scope {scope} already found'
dim, dim_head, heads, r, device = self.dim, self.dim_head, self.heads, default(lora_r, self.lora_r), self.device
q_inner_dim = heads * dim_head
kv_inner_dim = dim_head
lora_modules = nn.ModuleList([])
for _ in range(len(self.layers)):
lora_modules.append(nn.ModuleList([
LoRA(dim, q_inner_dim, r = r), # queries
LoRA(dim, kv_inner_dim, r = r), # keys
LoRA(dim, kv_inner_dim, r = r), # values
LoRA(q_inner_dim, dim, r = r) # wo
]))
self.finetune_modules[scope] = lora_modules.to(device)
def remove_finetune_params(self, scope):
assert scope in self.finetune_modules, f'finetune scope {scope} not found'
return self.finetune_modules.pop(scope)
@torch.no_grad()
def merge_finetune_params(self, scope):
""" in the case one wants to merge the fine-tuned actor LORA parameters and do multiple rounds of fine tuning off different reward models """
assert scope in self.finetune_modules, f'finetune scope {scope} not found'
lora_modules = self.finetune_modules.pop(scope)
for layer, (lora_q, lora_k, lora_v, lora_o) in zip(self.layers, lora_modules):
block = layer.fn
fused_attn_ff_weight = block.fused_attn_ff_proj.weight
attn_out_weight = block.attn_out.weight
fused_proj_out_dim = fused_attn_ff_weight.shape[0]
lora_qkv_weight, _ = pack([lora_q.weight, lora_k.weight, lora_v.weight], 'i *')
lora_qkv_weight = F.pad(lora_qkv_weight, (0, fused_proj_out_dim - lora_qkv_weight.shape[1]))
lora_qkv_weight = rearrange(lora_qkv_weight, 'i o -> o i')
lora_o_weight = rearrange(lora_o.weight, 'i o -> o i')
fused_attn_ff_weight.add_(lora_qkv_weight)
attn_out_weight.add_(lora_o_weight)
# researchers train the palm parameters first,
# before any finetuning
def palm_parameters(self):
return set(self.parameters()) - set(self.finetune_modules.parameters())
def finetune_parameters(self, scope = 'default'):
assert scope in self.finetune_modules, f'finetune parameters of scope {scope} not found'
return self.finetune_modules[scope].parameters()
# generate function
@torch.no_grad()
@eval_decorator
def generate(
self,
seq_len,
prompt = None,
temperature = 1.,
filter_logits_fn = top_k,
filter_thres = 0.9,
pad_value = 0.,
eos_token = None,
return_seq_without_prompt = True,
use_tqdm = False,
**kwargs
):
if not exists(prompt):
prompt = torch.randint(0, self.num_tokens, (1, 1))
prompt = prompt.to(self.device)
return_seq_without_prompt = False
prompt, leading_dims = pack([prompt], '* n')
n, out = prompt.shape[-1], prompt.clone()
wrapper_fn = identity if not use_tqdm else tqdm
sample_num_times = max(1, seq_len - prompt.shape[-1])
for _ in wrapper_fn(range(sample_num_times)):
logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs)
logits, embeds = logits[:, -1], embeds[:, -1]
if exists(filter_logits_fn):
logits = filter_logits_fn(logits, thres = filter_thres)
sample = gumbel_sample(logits, temperature = temperature, dim = -1)
out, _ = pack([out, sample], 'b *')
if exists(eos_token):
is_eos_tokens = (out == eos_token)
if is_eos_tokens.any(dim = -1).all():
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
out = out.masked_fill(mask, pad_value)
break
out, = unpack(out, leading_dims, '* n')
if not return_seq_without_prompt:
return out
return out[..., n:]
def forward(
self,
x,
return_loss = False,
disable_lora = False,
finetune_scope = None,
extra_embed = None,
return_only_embedding = False,
return_logits_with_embedding = False
):
if return_loss:
x, labels = x[:, :-1], x[:, 1:]
# mask if encoder
# treat any token ids that are negative as tokens to mask out - only needed if not autoregressive
if not self.causal:
mask = x >= 0
x = x.masked_fill(~mask, 0)
else:
mask = None
# get token embedding
x = self.token_emb(x)
if exists(extra_embed):
x = x + extra_embed
# finetune modules
finetune_modules = tuple()
if exists(finetune_scope) and not disable_lora:
assert finetune_scope in self.finetune_modules
finetune_modules = self.finetune_modules[finetune_scope]
# parallel attention / ff blocks, passing in finetuning loras
for layer, finetune_modules in zip_longest(self.layers, finetune_modules):
x = layer(x, mask = mask, finetune_modules = finetune_modules)
# final norm
embeds = self.norm(x)
if return_only_embedding:
return embeds
# to logits
logits = self.to_logits(embeds)  # project the normed embeddings (not the pre-norm x) to vocab logits
ret = (logits, embeds) if return_logits_with_embedding else logits
if not return_loss:
return ret
logits = rearrange(logits, 'b n c -> b c n')
return F.cross_entropy(logits, labels, ignore_index = self.cross_entropy_ignore_index)
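# minimal usage sketch (illustrative, not part of the commit): a tiny PaLM
# computing the shifted next-token cross-entropy loss
def _palm_demo():
    palm = PaLM(dim = 512, num_tokens = 20000, depth = 2)
    seq = torch.randint(0, 20000, (1, 128))
    loss = palm(seq, return_loss = True)
    loss.backward()
########################################################################################################
# file boundary (inferred): the PPO / RLHF trainer below was copied from the upstream
# palm_rlhf_pytorch project (note it still imports from the palm_rlhf_pytorch package);
# it likely lives at src/rlhf/ppo.py
########################################################################################################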
import math
from pathlib import Path
import copy
from tqdm import tqdm
from functools import partial
from collections import deque, namedtuple
from random import randrange
from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from palm_rlhf_pytorch.palm import PaLM
from palm_rlhf_pytorch.reward import RewardModel
from palm_rlhf_pytorch.optimizer import get_optimizer
from palm_rlhf_pytorch.utils import masked_mean, eval_decorator
from accelerate import Accelerator
# actor critic - PaLM with lora
PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
'actions',
'sequence',
'mask',
'prompt_mask',
'action_logits',
'values'
])
@beartype
class ActorCritic(nn.Module):
def __init__(
self,
palm: PaLM,
critic_palm: Optional[PaLM] = None,
pooled_values = False,
actor_lora = True,
critic_lora = True,
actor_lora_r = 8,
critic_lora_r = 8,
actor_lora_scope = 'actor',
critic_lora_scope = 'critic',
actor_dropout = 0.,
critic_dropout = 0.
):
super().__init__()
self.actor_palm = palm
self.critic_palm = critic_palm
if not exists(self.critic_palm):
self.critic_palm = copy.deepcopy(palm)
self.actor_palm.set_dropout(actor_dropout)
self.critic_palm.set_dropout(critic_dropout)
self.actor_lora = actor_lora
self.critic_lora = critic_lora
self.actor_lora_scope = actor_lora_scope if actor_lora else None
self.critic_lora_scope = critic_lora_scope if critic_lora else None
if self.actor_lora:
self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)
if self.critic_lora:
self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
self.pooled_values = pooled_values
self.value_head = nn.Sequential(
nn.Linear(palm.dim, 1),
Rearrange('... 1 -> ...')
)
nn.init.zeros_(self.value_head[0].bias)
nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
def actor_parameters(self):
if not self.actor_lora:
return self.actor_palm.parameters()
return [
*self.actor_palm.finetune_parameters(self.actor_lora_scope)
]
def critic_parameters(self):
if not self.critic_lora:
return [*self.critic_palm.parameters(), *self.value_head.parameters()]
return [
*self.critic_palm.finetune_parameters(self.critic_lora_scope),
*self.value_head.parameters()
]
@torch.no_grad()
@eval_decorator
def generate(
self,
state,
max_seq_len,
eos_token = None,
return_values = False,
**kwargs
):
actions = self.actor_palm.generate(
max_seq_len,
prompt = state,
eos_token = eos_token,
finetune_scope = self.actor_lora_scope,
use_tqdm = True,
**kwargs
)
sequence = torch.cat((state, actions), dim = -1)
action_len = actions.shape[-1]
state_len = state.shape[-1]
prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])
action_mask = ~prompt_mask
mask = None
if exists(eos_token):
mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
mask = F.pad(mask, (1, -1), value = True) # include eos token
action_mask &= mask
action_logits, value = self.forward(
sequence,
mask = action_mask,
return_values = return_values
)
return PPOActionCriticReturn(
actions,
sequence,
mask,
prompt_mask,
action_logits,
value
)
def forward(
self,
x,
mask = None,
return_values = True
):
action_logits = self.actor_palm(
x,
finetune_scope = self.actor_lora_scope
)
if not return_values:
return action_logits, None
critic_embeds = self.critic_palm(
x,
return_only_embedding = True,
finetune_scope = self.critic_lora_scope
)
if self.pooled_values:
critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
critic_embeds = masked_mean(critic_embeds, mask, dim = 1)
values = self.value_head(critic_embeds)
return action_logits, values
# data
Memory = namedtuple('Memory', [
'sequence',
'prompt_mask',
'mask',
'action_prob',
'action_log_prob',
'reward',
'value'
])
@beartype
class ExperienceDataset(Dataset):
def __init__(
self,
data: List[torch.Tensor],
device = None
):
super().__init__()
self.data = data
self.device = device
def __len__(self):
return self.data[0].shape[0]
def __getitem__(self, ind):
return tuple(map(lambda t: t[ind].to(self.device), self.data))
def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
ds = ExperienceDataset(data, device = device)
return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
dim = default(dim, tuple(range(t.ndim)))
kwargs = dict(dim = dim, keepdim = True)
mean = masked_mean(t, mask = mask, **kwargs)
mean_centered = t - mean
var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)
return mean_centered * var.clamp(min = eps).rsqrt()
def pad_sequence_fixed(sequences, *args, **kwargs):
first_el = sequences[0]
has_no_dimension = first_el.ndim == 0
# if no dimensions, add a single dimension
if has_no_dimension:
sequences = tuple(map(lambda t: t[None], sequences))
out = pad_sequence(sequences, *args, **kwargs)
if has_no_dimension:
out = rearrange(out, '... 1 -> ...')
return out
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
def log_prob(prob, indices):
assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
return log(prob.gather(-1, indices[..., None])).squeeze(-1)
def shift(t, value = 0, shift = 1, dim = -1):
zeros = (0, 0) * (-dim - 1)
return F.pad(t, (*zeros, shift, -shift), value = value)
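# worked example (illustrative, not part of the commit): shift pads `shift`
# positions at the front of `dim` and trims the same amount from the back, e.g.
# [1, 2, 3] -> [0, 1, 2]; the PPO code uses it to align next-token logits with
# the token that was actually sampled at that position
def _shift_demo():
    t = torch.tensor([1, 2, 3])
    assert shift(t, value = 0, shift = 1, dim = -1).tolist() == [0, 1, 2]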
def masked_entropy(prob, dim = -1, mask = None):
entropies = (prob * log(prob)).sum(dim = -1)
return masked_mean(entropies, mask = mask).mean()
def masked_kl_div(prob1, prob2, mask = None):
"""
need to account for variable sequence lengths, therefore not using the built-in functional version
"""
kl_divs = (prob1 * (log(prob2) - log(prob1))).sum(dim = -1)
if not exists(mask):
return kl_divs.mean()
return masked_mean(kl_divs, mask).mean()
def clipped_value_loss(values, rewards, old_values, clip):
value_clipped = old_values + (values - old_values).clamp(-clip, clip)
value_loss_1 = (value_clipped.flatten() - rewards) ** 2
value_loss_2 = (values.flatten() - rewards) ** 2
return torch.mean(torch.max(value_loss_1, value_loss_2))
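# worked example (illustrative, not part of the commit): the critic value may
# only move `clip` away from its old estimate, and the worse of the clipped /
# unclipped squared errors is taken
def _value_loss_demo():
    values, old_values, rewards = torch.tensor([2.]), torch.tensor([0.]), torch.tensor([1.])
    # clipped value = 0 + clamp(2 - 0, -0.4, 0.4) = 0.4
    # loss = max((0.4 - 1)^2, (2 - 1)^2) = max(0.36, 1.0) = 1.0
    assert torch.isclose(clipped_value_loss(values, rewards, old_values, clip = 0.4), torch.tensor(1.))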
# rlhf trainer
@beartype
class RLHFTrainer(nn.Module):
def __init__(
self,
*,
prompts: Optional[List[str]] = None,
prompts_path: Optional[str] = None,
prompt_token_ids: Optional[torch.Tensor] = None,
tokenizer: Callable = None,
palm: PaLM,
reward_model: RewardModel,
actor_critic: Optional[ActorCritic] = None,
actor_lr = 1e-4,
critic_lr = 1e-4,
actor_wd = 0.,
critic_wd = 0.,
actor_adam_eps = 1e-7,
critic_adam_eps = 1e-7,
actor_lora = True,
critic_lora = True,
actor_lora_r = 8,
critic_lora_r = 8,
critic_pooled_values = True,
actor_dropout = 0.,
critic_dropout = 0.,
betas = (0.9, 0.999),
max_norm = None,
eps_clip = 0.2,
value_clip = 0.4,
beta_s = .01,
pad_value = 0.,
minibatch_size = 16,
epochs = 1,
kl_div_loss_weight = 0.1, # between old action probs and new action probs - not sure what the right value is
accelerate_kwargs: dict = {},
use_lion = False
):
super().__init__()
self.accelerate = Accelerator(**accelerate_kwargs)
# take care of prompts -> token ids
assert (exists(prompts) + exists(prompts_path) + exists(prompt_token_ids)) == 1
if exists(prompts_path):
path = Path(prompts_path)
prompts = path.read_text().split('\n')
if exists(prompts):
assert len(prompts) > 0, 'no prompts'
assert exists(tokenizer), 'tokenizer must be passed in if raw text prompts are given'
prompt_token_ids = tokenizer(prompts)
self.pad_value = pad_value # token pad value
self.num_prompts = prompt_token_ids.shape[0]
self.register_buffer('prompt_token_ids', prompt_token_ids)
# models
self.palm = palm
if not exists(actor_critic):
actor_critic = ActorCritic(
palm = palm,
actor_lora = actor_lora,
critic_lora = critic_lora,
actor_lora_r = actor_lora_r,
critic_lora_r = critic_lora_r,
pooled_values = critic_pooled_values,
actor_dropout = actor_dropout,
critic_dropout = critic_dropout
).to(palm.device)
self.actor_critic = actor_critic
self.reward_model = reward_model.eval()
# train hyperparameters
self.epochs = epochs
self.minibatch_size = minibatch_size
self.max_norm = max_norm
self.kl_div_loss_weight = kl_div_loss_weight
# optimizers
self.actor_optim = get_optimizer(actor_critic.actor_parameters(), lr = actor_lr, wd = actor_wd, betas = betas, eps = actor_adam_eps, use_lion = use_lion)
self.critic_optim = get_optimizer(actor_critic.critic_parameters(), lr = critic_lr, wd = critic_wd, betas = betas, eps = critic_adam_eps, use_lion = use_lion)
# ppo hyperparams
self.eps_clip = eps_clip
self.value_clip = value_clip
self.beta_s = beta_s
# prepare with accelerator
(
self.actor_critic,
self.reward_model,
self.actor_optim,
self.critic_optim
) = self.accelerate.prepare(
self.actor_critic,
self.reward_model,
self.actor_optim,
self.critic_optim
)
def print(self, msg):
return self.accelerate.print(msg)
def save(self, filepath = './checkpoint.pt'):
torch.save(self.actor_critic.state_dict(), filepath)
def load(self, filepath = './checkpoint.pt'):
state_dict = torch.load(filepath)
self.actor_critic.load_state_dict(state_dict)
@property
def device(self):
return self.accelerate.device
@torch.no_grad()
def generate(
self,
max_seq_len,
*args,
prompt,
num_samples = 4, # sample 4 per prompt and select the one with highest reward
**kwargs
):
assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
prompt = repeat(prompt, 'n -> b n', b = num_samples)
actor_critic = self.accelerate.unwrap_model(self.actor_critic)
reward_model = self.accelerate.unwrap_model(self.reward_model)
actor_critic.eval()
(
actions,
sequences,
mask,
prompt_mask,
action_logits,
_
) = actor_critic.generate(
prompt,
*args,
max_seq_len = max_seq_len,
return_values = False,
**kwargs
)
rewards = reward_model(
sequences,
prompt_mask = prompt_mask,
mask = mask,
sample = True
)
best_sequence_index = rewards.topk(1, dim = -1).indices
best_sequence = sequences[best_sequence_index]
best_sequence = rearrange(best_sequence, '1 ... -> ...')
return best_sequence
def learn(
self,
memories: Deque[Memory]
):
# stack all data stored in the memories
all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories)))
# prepare dataloader for policy phase training
dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device)
self.actor_critic.train()
# PPO training
for _ in range(self.epochs):
for (
sequences,
prompt_masks,
masks,
old_action_probs,
old_log_probs,
rewards,
old_values
) in dl:
action_masks = ~prompt_masks & masks
action_logits, values = self.actor_critic(
sequences,
mask = action_masks
)
action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
action_len = old_log_probs.shape[-1]
action_probs = action_logits.softmax(dim = -1)
action_log_probs = log_prob(action_probs, sequences)
action_log_probs = action_log_probs[:, -action_len:]
# calculate entropies, taking into account which part of the sequence is actually an action
entropies = masked_entropy(action_probs, mask = action_masks)
# calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
kl_div_loss = 0.
if self.kl_div_loss_weight > 0:
kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.kl_div_loss_weight
# handle non-pooled values
normalize_kwargs = dict()
if old_values.ndim == 2:
old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))
old_values = old_values[:, -action_len:]
values = values[:, -action_len:]
rewards = rearrange(rewards, 'b -> b 1')
normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])
if values.ndim < rewards.ndim:
values = rearrange(values, '... -> ... 1')
# calculate clipped surrogate objective, classic PPO loss
ratios = (action_log_probs - old_log_probs).exp()
advantages = masked_normalize(rewards - old_values, **normalize_kwargs)
if advantages.ndim == 1:
advantages = rearrange(advantages, 'b -> b 1')
surr1 = ratios * advantages
surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages
policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropies
# combine losses
loss = policy_loss.mean() + kl_div_loss
# update actor
self.accelerate.backward(loss)
self.print(f'policy_loss: {loss.item():.3f}')
if exists(self.max_norm):
self.accelerate.clip_grad_norm_(self.actor_critic.actor_parameters(), self.max_norm)
self.actor_optim.step()
self.actor_optim.zero_grad()
# calculate value loss and update value network separate from policy network
value_loss = clipped_value_loss(values, rewards, old_values, self.value_clip)
value_loss = value_loss.mean()
self.print(f'critic_loss: {value_loss.item():.3f}')
self.accelerate.backward(value_loss)
if exists(self.max_norm):
self.accelerate.clip_grad_norm_(self.actor_critic.critic_parameters(), self.max_norm)
self.critic_optim.step()
self.critic_optim.zero_grad()
def train(
self,
num_episodes = 50000,
max_timesteps = 500,
update_timesteps = 5000,
max_batch_size = 16,
max_seq_len = 2048,
eos_token = None,
temperature = 1.
):
device = self.device
time = 0
memories = deque([])
for eps in tqdm(range(num_episodes), desc = 'episodes'):
for timestep in range(max_timesteps):
time += 1
# select a bunch of random states (prompts)
# and get the action (sampled sequence from palm as well as the action probs)
# also calculate the reward using reward model and store
rand_prompt_index = randrange(0, self.num_prompts)
state = self.prompt_token_ids[rand_prompt_index]
# remove padding from state
state_mask = state != self.pad_value
state = state[state_mask]
# get predicted sequence
(
actions,
sequence,
mask,
prompt_mask,
action_logits,
value
) = self.actor_critic.generate(
rearrange(state, 'n -> 1 n'),
max_seq_len = max_seq_len,
eos_token = eos_token,
temperature = temperature,
return_values = True
)
action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
action_prob = action_logits.softmax(dim = -1)
action_len = actions.shape[-1]
action_log_prob = log_prob(action_prob, sequence)
action_log_prob = action_log_prob[:, -action_len:]
actions = rearrange(actions, '1 ... -> ...')
# get reward as given by supervised trained reward model
sequence = torch.cat((state, actions), dim = 0)
prompt_length = len(state)
prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length
sequence = rearrange(sequence, 'n -> 1 n')
prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)
reward = self.reward_model(
sequence,
prompt_mask = prompt_mask,
mask = mask,
sample = True
)
detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')
# store memory for learning
memories.append(Memory(*map(detach_to_cpu_, (
sequence,
prompt_mask,
mask,
action_prob,
action_log_prob,
reward,
value
))))
# learn from the stored memories
if time % update_timesteps == 0:
self.learn(memories)
memories.clear()
print('rlhf training complete')
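########################################################################################################
# file boundary (inferred): the reward model below imports src.rlhf.utils and src.model, and the
# smoke test at the end imports it as src.rlhf.reward, so this is likely src/rlhf/reward.py
########################################################################################################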
import copy
from pathlib import Path
from tqdm import tqdm
from beartype import beartype
from beartype.typing import Tuple, Optional
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from src.rlhf.utils import masked_mean, gumbel_sample
from src.model import RWKV
# helper functions
def exists(val):
return val is not None
# Reward Model - RWKV with a scalar head
@beartype
class RewardModel(nn.Module):
def __init__(
self,
rwkv: RWKV,
dropout = 0.1,
num_binned_output = 0  # number of score bins; 0 or 1 means a single scalar score
):
super().__init__()
# initialize the reward model from the pretrained RWKV model
self.rwkv = copy.deepcopy(rwkv)
self.rwkv.set_dropout(dropout) # todo(luxin)
# dimension of the output token embeddings
dim = rwkv.dim # todo(luxin)
# number of score bins: a value of 5 means the score takes one of the 5 levels [0, 1, 2, 3, 4]
self.binned_output = num_binned_output > 1
# todo(luxin): prompt_embed and response_embed are both initialized to all zeros -- shouldn't they be distinguished?
self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
# self.response_embed = nn.Parameter(torch.ones(1, 1, dim))
if self.binned_output:
# if there is more than one score bin, scoring is a multi-class classification problem
self.to_pred = nn.Linear(dim, num_binned_output)
else:
# otherwise, predict a single scalar score (trained below with an MSE regression loss)
self.to_pred = nn.Sequential(
nn.Linear(dim, 1, bias = False),
Rearrange('... 1 -> ...') # drop the trailing singleton dimension
)
def load(self, path):
path = Path(path)
assert path.exists()
self.load_state_dict(torch.load(str(path)))
def finetune_parameters(self):
return [
*self.to_pred.parameters(),
*self.rwkv.parameters()
]
def forward(
self,
x,
mask = None,
prompt_mask = None,
prompt_lengths = None,
labels = None,
sample = False,
sample_temperature = 1.
):
# at most one of prompt_mask and prompt_lengths may be given
assert not (exists(prompt_mask) and exists(prompt_lengths))
# derive prompt mask from prompt lengths
if exists(prompt_lengths):
batch, seq_len = x.shape
arange = torch.arange(seq_len, device = x.device)
prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')
# reward model should have an understanding of which section is prompt, and which section is response
# pick a per-token extra embedding according to prompt_mask:
# True selects prompt_embed, False selects response_embed
extra_embed = None
if exists(prompt_mask):
extra_embed = torch.where(
rearrange(prompt_mask, 'b n -> b n 1'),
self.prompt_embed,
self.response_embed
)
# todo(luxin) get embeddings from rwkv
embeds = self.rwkv(
x,
extra_embed = extra_embed,
return_only_embedding = True
)
# average the (masked) token embeddings and feed the pooled vector to the scoring head
pooled = masked_mean(embeds, mask, dim = 1)
pred = self.to_pred(pooled)
if sample and self.binned_output:
assert not exists(labels)
pred = gumbel_sample(pred, temperature = sample_temperature, dim = -1)
if not exists(labels):
return pred
# todo(luxin): the author trains with a per-sample loss here, not the pairwise (chosen vs. rejected) ranking loss from the paper
if not self.binned_output:
return F.mse_loss(pred, labels)
return F.cross_entropy(pred, labels)
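# hedged sketch (illustrative, not part of the commit): the pairwise ranking
# loss from the InstructGPT-style papers that the todo above refers to, i.e.
# -log sigmoid(r_chosen - r_rejected) over preference pairs
def pairwise_reward_loss(chosen_reward, rejected_reward):
    # both arguments are per-sequence scalar rewards of shape (batch,)
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()
########################################################################################################
# file boundary (inferred): the helpers below are imported as src.rlhf.utils throughout,
# so this is likely src/rlhf/utils.py
########################################################################################################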
import math
import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange
def exists(val):
return val is not None
# decorators
def eval_decorator(fn):
def inner(self, *args, **kwargs):
was_training = self.training
self.eval()
out = fn(self, *args, **kwargs)
self.train(was_training)
return out
return inner
# tensor helpers
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
def masked_mean(seq, mask = None, dim = 1, keepdim = False):
if not exists(mask):
return seq.mean(dim = dim)
if seq.ndim == 3:
mask = rearrange(mask, 'b n -> b n 1')
masked_seq = seq.masked_fill(~mask, 0.)
numer = masked_seq.sum(dim = dim, keepdim = keepdim)
denom = mask.sum(dim = dim, keepdim = keepdim)
masked_mean = numer / denom.clamp(min = 1e-3)
masked_mean = masked_mean.masked_fill(denom == 0, 0.)
return masked_mean
# sampling helpers
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
def gumbel_sample(t, temperature = 1., dim = -1):
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
def top_p(logits, thres = 0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > (1 - thres)
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0
sorted_logits[sorted_indices_to_remove] = float('-inf')
return sorted_logits.scatter(1, sorted_indices, sorted_logits)
def top_k(logits, thres = 0.9):
k = math.ceil((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
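# minimal usage sketch (illustrative, not part of the commit): top_k keeps the
# ceil((1 - thres) * vocab) highest logits, here the top 2000 of 20000, and
# masks the rest to -inf before sampling
def _top_k_demo():
    logits = torch.randn(1, 20000)
    filtered = top_k(logits, thres = 0.9)
    assert torch.isfinite(filtered).sum() == 2000
########################################################################################################
# file boundary (inferred): a reward-model smoke test follows (its path is not shown in
# the commit excerpt)
########################################################################################################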
import torch
from src.rlhf.reward import RewardModel
from src.model import RWKV
rwkv_model = RWKV()  # note: RWKV is constructed with an args namespace in train_sft.py above; none is passed here
reward_model = RewardModel(
rwkv_model,
num_binned_output = 5 # number of score bins: 5 gives the levels [0, 1, 2, 3, 4]
)
# mock data
seq = torch.randint(0, 20000, (1, 100))
# prompt_mask = torch.zeros(1, 100).bool() # which part of the sequence is prompt, which part is response
prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
labels = torch.randint(0, 5, (1,))
# train
loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()
# after much training: without labels the model returns its raw prediction, here a
# (1, 5) tensor of bin logits (pass sample = True to draw a discrete score instead)
reward = reward_model(seq, prompt_mask = prompt_mask)