Commit 5f0fb2fd authored by u010280923

add reward model

Parent 6f1de051
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/3/1 11:54
# @Author : clong
# @File : train_sft.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
if __name__ == "__main__":
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
rank_zero_info("########## work in progress ##########")
########################################################################################################
#
# example: train a simple L12-D768 RWKV on dummy data
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "" --data_type "dummy" --vocab_size 0 \
# --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
# --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
# --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: train a simple L6-D512 RWKV from scratch on enwik8
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
# example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
# --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
parser = ArgumentParser()
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--random_seed", default="-1", type=int)
parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int)
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
parser.add_argument("--my_img_version", default=0, type=str)
parser.add_argument("--my_img_size", default=0, type=int)
parser.add_argument("--my_img_bit", default=0, type=int)
parser.add_argument("--my_img_clip", default='x', type=str)
parser.add_argument("--my_img_clip_scale", default=1, type=float)
parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str)
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int)
parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
########################################################################################################
import os, warnings, math, datetime, sys, time
import numpy as np
import torch
from torch.utils.data import DataLoader
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
seed_everything(args.random_seed)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
# os.environ["WDS_SHOW_SEED"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False
args.replace_sampler_ddp = False
args.logger = False
args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0
args.check_val_every_n_epoch = int(1e20)
args.log_every_n_steps = int(1e20)
args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
os.environ["RWKV_MY_TESTING"] = args.my_testing
if args.dim_att <= 0:
args.dim_att = args.n_embd
if args.dim_ffn <= 0:
args.dim_ffn = args.n_embd * 4
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
if not os.path.exists(args.proj_dir):
os.makedirs(args.proj_dir)
samples_per_epoch = args.epoch_steps * args.real_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len
rank_zero_info(
f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
)
rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT_ON"] = "0"
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
if args.precision == "fp32":
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if "32" in args.precision:
args.precision = 32
elif args.precision == "fp16":
args.precision = 16
else:
args.precision = "bf16"
########################################################################################################
from src.trainer import train_callback, generate_init_weight
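# quick smoke test: hard-code the vocab size, build the RWKV model, run one
# forward pass on random token ids, then drop into the ipdb debugger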
args.vocab_size = 20000
from src.model import RWKV
model = RWKV(args)
seq = torch.randint(0, 20000, (1, 100))
model(seq)
import ipdb
ipdb.set_trace()
# from palm_rlhf_pytorch.palm import PaLM
# from palm_rlhf_pytorch.reward import RewardModel
# from palm_rlhf_pytorch.ppo import RLHFTrainer, ActorCritic
import torch
from torch import nn
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# LoRA - https://arxiv.org/abs/2106.09685
class LoRA(nn.Module):
def __init__(
self,
dim,
dim_out,
r = 8,
alpha = None
):
super().__init__()
alpha = default(alpha, r)
self.scale = alpha / r
self.A = nn.Parameter(torch.randn(dim, r))
self.B = nn.Parameter(torch.zeros(r, dim_out))
@property
def weight(self):
return (self.A @ self.B) * self.scale
def forward(self, x):
return x @ self.weight
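# a minimal usage sketch (illustrative, not part of this commit): because B is
# zero-initialized, the low-rank delta starts at exactly zero, so fine-tuning
# begins from the frozen base weights
#
#   lora = LoRA(dim = 512, dim_out = 512, r = 8)
#   x = torch.randn(2, 16, 512)
#   delta = lora(x)                    # (2, 16, 512), all zeros before training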
from torch.optim import AdamW, Adam
from lion_pytorch import Lion
def separate_weight_decayable_params(params):
wd_params, no_wd_params = [], []
for param in params:
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
return wd_params, no_wd_params
def get_optimizer(
params,
lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.99),
eps = 1e-8,
filter_by_requires_grad = False,
group_wd_params = True,
use_lion = True,
**kwargs
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if group_wd_params and wd > 0:
wd_params, no_wd_params = separate_weight_decayable_params(params)
params = [
{'params': wd_params},
{'params': no_wd_params, 'weight_decay': 0},
]
if use_lion:
return Lion(params, lr = lr, betas = betas, weight_decay = wd)
if wd == 0:
return Adam(params, lr = lr, betas = betas, eps = eps)
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
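# usage sketch (illustrative): parameters with ndim < 2 (biases, norm scales)
# are routed to a no-weight-decay group before the optimizer is constructed
#
#   net = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))
#   opt = get_optimizer(net.parameters(), lr = 1e-4, wd = 1e-2, use_lion = True)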
import math
import copy
from pathlib import Path
from collections import namedtuple
from itertools import zip_longest
from tqdm import tqdm
from beartype import beartype
from beartype.typing import Tuple, Optional
import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from src.rlhf.utils import top_p, top_k, masked_mean, gumbel_sample, eval_decorator
from src.rlhf.lora import LoRA
# functions and decorators
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def identity(t, *args, **kwargs):
return t
def l2norm(t):
return F.normalize(t, dim = -1)
# normalization
# they use layernorm without bias, something that pytorch does not offer
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
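# usage sketch (illustrative): gamma is learned while beta stays a fixed
# all-zero buffer, giving a bias-free layernorm
#
#   ln = LayerNorm(512)
#   y = ln(torch.randn(2, 10, 512))    # same shape, normalized over the last dim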
# residual
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
y = self.fn(x, **kwargs)
if not any([t.requires_grad for t in (x, y)]):
return x.add_(y)
return y + x
# rotary positional embedding w/ xpos
# https://arxiv.org/abs/2104.09864
# https://arxiv.org/abs/2212.10554v1
class RotaryEmbedding(nn.Module):
def __init__(self, dim, scale_base = 512, use_xpos = True):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.use_xpos = use_xpos
self.scale_base = scale_base
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.register_buffer('scale', scale)
def forward(self, seq_len, device):
t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)
if not self.use_xpos:
return freqs, torch.ones(1, device = device)
power = (t - (seq_len // 2)) / self.scale_base
scale = self.scale ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
return freqs, scale
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(pos, t, scale = 1.):
return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)
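# worked example (illustrative): queries are scaled by `scale` and keys by
# `scale ** -1`, so the q.k dot product decays with relative distance (xpos)
#
#   rot = RotaryEmbedding(64)
#   pos, scale = rot(8, device = torch.device('cpu'))
#   q = k = torch.randn(1, 8, 8, 64)             # (batch, heads, seq, dim_head)
#   q = apply_rotary_pos_emb(pos, q, scale)
#   k = apply_rotary_pos_emb(pos, k, scale ** -1)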
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202
class SwiGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
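# usage sketch (illustrative): the input is split in half along the last dim,
# one half gating the other through SiLU
#
#   x = torch.randn(2, 10, 4096)       # last dim = 2 * ff_inner_dim
#   out = SwiGLU()(x)                  # (2, 10, 2048)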
# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame
class ParallelTransformerBlock(nn.Module):
def __init__(
self,
dim,
dim_head = 64,
causal = True,
heads = 8,
qk_rmsnorm = False,
qk_scale = 8,
ff_mult = 4,
attn_dropout = 0.,
ff_dropout = 0.,
use_xpos = True,
xpos_scale_base = 512
):
super().__init__()
self.norm = LayerNorm(dim)
attn_inner_dim = dim_head * heads
ff_inner_dim = dim * ff_mult
self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
self.qk_rmsnorm = qk_rmsnorm
if qk_rmsnorm:
self.q_scale = nn.Parameter(torch.ones(dim_head))
self.k_scale = nn.Parameter(torch.ones(dim_head))
self.heads = heads
self.scale = (dim_head ** -0.5) if not qk_rmsnorm else qk_scale
self.causal = causal
self.rotary_emb = RotaryEmbedding(dim_head, scale_base = xpos_scale_base, use_xpos = use_xpos and causal)
self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
self.attn_dropout = nn.Dropout(attn_dropout)
# parallel feedforward tail
self.ff_out = nn.Sequential(
SwiGLU(),
nn.Dropout(ff_dropout),
nn.Linear(ff_inner_dim, dim, bias=False)
)
# for caching causal mask and rotary embeddings
self.register_buffer("mask", None, persistent=False)
self.register_buffer("pos_emb", None, persistent=False)
self.register_buffer("pos_emb_scale", None, persistent=False)
def get_mask(self, n, device):
if exists(self.mask) and self.mask.shape[-1] >= n:
return self.mask[:n, :n]
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask
def get_rotary_embedding(self, n, device):
if exists(self.pos_emb) and self.pos_emb.shape[-2] >= n:
return self.pos_emb[:n], self.pos_emb_scale[:n]
pos_emb, scale = self.rotary_emb(n, device=device)
self.register_buffer("pos_emb", pos_emb, persistent=False)
self.register_buffer("pos_emb_scale", scale, persistent=False)
return pos_emb, scale
def forward(
self,
x,
mask = None,
finetune_modules = None
):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
n, device, h = x.shape[1], x.device, self.heads
# pre layernorm
x = self.norm(x)
# attention queries, keys, values, and feedforward inner
q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
# finetune loras
lora_q = lora_k = lora_v = lora_o = None
if exists(finetune_modules):
lora_q, lora_k, lora_v, lora_o = finetune_modules
q = q + lora_q(x)
k = k + lora_k(x)
v = v + lora_v(x)
# split heads
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
# they found no performance loss past a certain scale, and more efficient decoding obviously
# https://arxiv.org/abs/1911.02150
q = rearrange(q, "b n (h d) -> b h n d", h=h)
# qk rmsnorm
if self.qk_rmsnorm:
q, k = map(l2norm, (q, k))
q = q * self.q_scale
k = k * self.k_scale
# rotary embeddings with xpos decay for better length extrapolation
positions, scale = self.get_rotary_embedding(n, device)
q = apply_rotary_pos_emb(positions, q, scale)
k = apply_rotary_pos_emb(positions, k, scale ** -1)
# similarity
sim = einsum("b h i d, b j d -> b h i j", q, k) * self.scale
# key padding mask
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
# causal mask
if self.causal:
causal_mask = self.get_mask(n, device)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# attention
attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)
# aggregate values
out = einsum("b h i j, b j d -> b h i d", attn, v)
# merge heads
out = rearrange(out, "b h n d -> b n (h d)")
attn_out = self.attn_out(out)
ff_out = self.ff_out(ff)
if exists(lora_o):
attn_out = attn_out + lora_o(out)
return attn_out + ff_out
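# usage sketch (illustrative): one block computes attention and feedforward in
# parallel from a single fused projection and returns the sum of both branches
#
#   block = ParallelTransformerBlock(dim = 512, dim_head = 64, heads = 8)
#   x = torch.randn(1, 32, 512)
#   y = block(x)                       # (1, 32, 512)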
# transformer
@beartype
class PaLM(nn.Module):
def __init__(
self,
*,
dim,
num_tokens,
depth,
causal = True,
dim_head = 64,
heads = 8,
ff_mult = 4,
attn_dropout = 0.,
ff_dropout = 0.,
qk_rmsnorm = False,
lora_r = 8,
rotary_xpos_scale_base = 512,
finetune_scopes = tuple(),
cross_entropy_ignore_index = 0
):
super().__init__()
self.dim = dim
self.dim_head = dim_head
self.heads = heads
self.causal = causal
self.num_tokens = num_tokens
self.token_emb = nn.Embedding(num_tokens, dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
block = Residual(ParallelTransformerBlock(
dim = dim,
causal = causal,
dim_head = dim_head,
heads = heads,
qk_rmsnorm = qk_rmsnorm,
ff_mult = ff_mult,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
xpos_scale_base = rotary_xpos_scale_base
))
self.layers.append(block)
self.norm = LayerNorm(dim)
self.to_logits = nn.Linear(dim, num_tokens, bias=False)
self.to_logits.weight = self.token_emb.weight
nn.init.normal_(self.token_emb.weight, std=0.02)
# fine tuning related
self.lora_r = lora_r
self.finetune_modules = nn.ModuleDict({})
for scope in finetune_scopes:
self.add_finetune_params(scope)
# loss related
self.cross_entropy_ignore_index = cross_entropy_ignore_index
@property
def device(self):
return next(self.parameters()).device
def load(self, path):
path = Path(path)
assert path.exists()
self.load_state_dict(torch.load(str(path)))
def set_dropout(self, dropout):
for module in self.layers.modules():
if isinstance(module, nn.Dropout):
module.p = dropout
return self
def add_finetune_params(self, scope, lora_r = None):
assert scope not in self.finetune_modules, f'finetune scope {scope} already found'
dim, dim_head, heads, r, device = self.dim, self.dim_head, self.heads, default(lora_r, self.lora_r), self.device
q_inner_dim = heads * dim_head
kv_inner_dim = dim_head
lora_modules = nn.ModuleList([])
for _ in range(len(self.layers)):
lora_modules.append(nn.ModuleList([
LoRA(dim, q_inner_dim, r = r), # queries
LoRA(dim, kv_inner_dim, r = r), # keys
LoRA(dim, kv_inner_dim, r = r), # values
LoRA(q_inner_dim, dim, r = r) # wo
]))
self.finetune_modules[scope] = lora_modules.to(device)
def remove_finetune_params(self, scope):
assert scope in self.finetune_modules, f'finetune scope {scope} not found'
return self.finetune_modules.pop(scope)
@torch.no_grad()
def merge_finetune_params(self, scope):
""" in the case one wants to merge the fine-tuned actor LORA parameters and do multiple rounds of fine tuning off different reward models """
assert scope in self.finetune_modules, f'finetune scope {scope} not found'
lora_modules = self.finetune_modules.pop(scope)
for layer, (lora_q, lora_k, lora_v, lora_o) in zip(self.layers, lora_modules):
block = layer.fn
fused_attn_ff_weight = block.fused_attn_ff_proj.weight
attn_out_weight = block.attn_out.weight
fused_proj_out_dim = fused_attn_ff_weight.shape[0]
lora_qkv_weight, _ = pack([lora_q.weight, lora_k.weight, lora_v.weight], 'i *')
lora_qkv_weight = F.pad(lora_qkv_weight, (0, fused_proj_out_dim - lora_qkv_weight.shape[1]))
lora_qkv_weight = rearrange(lora_qkv_weight, 'i o -> o i')
lora_o_weight = rearrange(lora_o.weight, 'i o -> o i')
fused_attn_ff_weight.add_(lora_qkv_weight)
attn_out_weight.add_(lora_o_weight)
# researchers train the palm parameters first, before finetuning
def palm_parameters(self):
return set(self.parameters()) - set(self.finetune_modules.parameters())
def finetune_parameters(self, scope = 'default'):
assert scope in self.finetune_modules, f'finetune parameters of scope {scope} not found'
return self.finetune_modules[scope].parameters()
# generate function
@torch.no_grad()
@eval_decorator
def generate(
self,
seq_len,
prompt = None,
temperature = 1.,
filter_logits_fn = top_k,
filter_thres = 0.9,
pad_value = 0.,
eos_token = None,
return_seq_without_prompt = True,
use_tqdm = False,
**kwargs
):
if not exists(prompt):
prompt = torch.randint(0, self.num_tokens, (1, 1))
prompt = prompt.to(self.device)
return_seq_without_prompt = False
prompt, leading_dims = pack([prompt], '* n')
n, out = prompt.shape[-1], prompt.clone()
wrapper_fn = identity if not use_tqdm else tqdm
sample_num_times = max(1, seq_len - prompt.shape[-1])
for _ in wrapper_fn(range(sample_num_times)):
logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs)
logits, embeds = logits[:, -1], embeds[:, -1]
if exists(filter_logits_fn):
logits = filter_logits_fn(logits, thres = filter_thres)
sample = gumbel_sample(logits, temperature = temperature, dim = -1)
out, _ = pack([out, sample], 'b *')
if exists(eos_token):
is_eos_tokens = (out == eos_token)
if is_eos_tokens.any(dim = -1).all():
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
out = out.masked_fill(mask, pad_value)
break
out, = unpack(out, leading_dims, '* n')
if not return_seq_without_prompt:
return out
return out[..., n:]
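# usage sketch (illustrative): autoregressive sampling from a prompt with
# top-k filtering and gumbel sampling
#
#   palm = PaLM(dim = 512, num_tokens = 20000, depth = 12)
#   prompt = torch.randint(0, 20000, (1, 8))
#   out = palm.generate(seq_len = 64, prompt = prompt, temperature = 0.8)
#   # out has shape (1, 56): the continuation without the prompt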
def forward(
self,
x,
return_loss = False,
disable_lora = False,
finetune_scope = None,
extra_embed = None,
return_only_embedding = False,
return_logits_with_embedding = False
):
if return_loss:
x, labels = x[:, :-1], x[:, 1:]
# mask if encoder
# treat any token ids that are negative as tokens to mask out - only needed if not autoregressive
if not self.causal:
mask = x >= 0
x = x.masked_fill(~mask, 0)
else:
mask = None
# get token embedding
x = self.token_emb(x)
if exists(extra_embed):
x = x + extra_embed
# finetune modules
finetune_modules = tuple()
if exists(finetune_scope) and not disable_lora:
assert finetune_scope in self.finetune_modules
finetune_modules = self.finetune_modules[finetune_scope]
# parallel attention / ff blocks, passing in finetuning loras
for layer, finetune_modules in zip_longest(self.layers, finetune_modules):
x = layer(x, mask = mask, finetune_modules = finetune_modules)
# final norm
embeds = self.norm(x)
if return_only_embedding:
return embeds
# to logits
logits = self.to_logits(embeds)
ret = (logits, embeds) if return_logits_with_embedding else logits
if not return_loss:
return ret
logits = rearrange(logits, 'b n c -> b c n')
return F.cross_entropy(logits, labels, ignore_index = self.cross_entropy_ignore_index)
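# training usage sketch (illustrative): with return_loss = True the input is
# shifted by one token and scored with next-token cross entropy
#
#   palm = PaLM(dim = 512, num_tokens = 20000, depth = 12)
#   seq = torch.randint(0, 20000, (1, 129))
#   loss = palm(seq, return_loss = True)
#   loss.backward()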
(This file's diff is collapsed.)
import copy
from pathlib import Path
from tqdm import tqdm
from beartype import beartype
from beartype.typing import Tuple, Optional
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from src.rlhf.utils import masked_mean, gumbel_sample
from src.model import RWKV
# helper functions
def exists(val):
return val is not None
# Reward Model - RWKV with a scalar head
@beartype
class RewardModel(nn.Module):
def __init__(
self,
rwkv: RWKV,
dropout = 0.1,
num_binned_output = 0
):
super().__init__()
# initialize the reward model from the pretrained model
self.rwkv = copy.deepcopy(rwkv)
self.rwkv.set_dropout(dropout) # todo(luxin)
# dimension of the output token embeddings
dim = rwkv.dim # todo(luxin)
# number of rating bins: with 5 bins, ratings fall into the 5 levels [0, 1, 2, 3, 4]
self.binned_output = num_binned_output > 1
# todo(luxin): prompt_embed and response_embed are both zero-initialized - shouldn't they differ?
self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
# self.response_embed = nn.Parameter(torch.ones(1, 1, dim))
if self.binned_output:
# more than one rating bin: a multi-class classification head
self.to_pred = nn.Linear(dim, num_binned_output)
else:
# otherwise a single scalar reward score (regression)
self.to_pred = nn.Sequential(
nn.Linear(dim, 1, bias = False),
Rearrange('... 1 -> ...') # drop the trailing singleton dimension
)
def load(self, path):
path = Path(path)
assert path.exists()
self.load_state_dict(torch.load(str(path)))
def finetune_parameters(self):
return [
*self.to_pred.parameters(),
*self.rwkv.parameters()
]
def forward(
self,
x,
mask = None,
prompt_mask = None,
prompt_lengths = None,
labels = None,
sample = False,
sample_temperature = 1.
):
# only one of prompt_mask and prompt_lengths may be given
assert not (exists(prompt_mask) and exists(prompt_lengths))
# derive prompt mask from prompt lengths
if exists(prompt_lengths):
batch, seq_len = x.shape
arange = torch.arange(seq_len, device = x.device)
prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')
# reward model should have an understanding of which section is prompt, and which section is response
# pick each token's extra embedding according to prompt_mask:
# True selects prompt_embed, False selects response_embed
extra_embed = None
if exists(prompt_mask):
extra_embed = torch.where(
rearrange(prompt_mask, 'b n -> b n 1'),
self.prompt_embed,
self.response_embed
)
# todo(luxin) get embeddings from rwkv
embeds = self.rwkv(
x,
extra_embed = extra_embed,
return_only_embedding = True
)
# mean-pool all token embeddings and feed the pooled vector to the scoring head
pooled = masked_mean(embeds, mask, dim = 1)
pred = self.to_pred(pooled)
if sample and self.binned_output:
assert not exists(labels)
pred = gumbel_sample(pred, temperature = sample_temperature, dim = -1)
if not exists(labels):
return pred
# todo(luxin): the author does not use the paper's pairwise (two-response) loss here, only a per-sample loss; see the sketch below
if not self.binned_output:
return F.mse_loss(pred, labels)
return F.cross_entropy(pred, labels)
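# sketch of the pairwise ranking loss the todo above refers to (as in
# InstructGPT, https://arxiv.org/abs/2203.02155); illustrative only, not used
# by this commit
def pairwise_ranking_loss(preferred_reward, rejected_reward):
    # push the preferred response's reward above the rejected response's reward
    return -torch.log(torch.sigmoid(preferred_reward - rejected_reward)).mean()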
import math
import torch
from torch import einsum, nn
import torch.nn.functional as F
from einops import rearrange
def exists(val):
return val is not None
# decorators
def eval_decorator(fn):
def inner(self, *args, **kwargs):
was_training = self.training
self.eval()
out = fn(self, *args, **kwargs)
self.train(was_training)
return out
return inner
# tensor helpers
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
def masked_mean(seq, mask = None, dim = 1, keepdim = False):
if not exists(mask):
return seq.mean(dim = dim)
if seq.ndim == 3:
mask = rearrange(mask, 'b n -> b n 1')
masked_seq = seq.masked_fill(~mask, 0.)
numer = masked_seq.sum(dim = dim, keepdim = keepdim)
denom = mask.sum(dim = dim, keepdim = keepdim)
masked_mean = numer / denom.clamp(min = 1e-3)
masked_mean = masked_mean.masked_fill(denom == 0, 0.)
return masked_mean
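# usage sketch (illustrative): padding positions are excluded from the mean
#
#   seq = torch.randn(2, 4, 8)
#   mask = torch.tensor([[True, True, False, False],
#                        [True, True, True, True]])
#   pooled = masked_mean(seq, mask)    # (2, 8), averaged over unmasked tokens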
# sampling helpers
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
def gumbel_sample(t, temperature = 1., dim = -1):
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
def top_p(logits, thres = 0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > (1 - thres)
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0
sorted_logits[sorted_indices_to_remove] = float('-inf')
return sorted_logits.scatter(1, sorted_indices, sorted_logits)
def top_k(logits, thres = 0.9):
k = math.ceil((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
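# usage sketch (illustrative): with thres = 0.9, top_k keeps the top 10% of
# logits, and gumbel_sample then draws one token from the filtered distribution
#
#   logits = torch.randn(1, 20000)
#   filtered = top_k(logits, thres = 0.9)        # top 2000 logits survive
#   token = gumbel_sample(filtered, temperature = 1.)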
import torch
from src.rlhf.reward import RewardModel
from src.model import RWKV
rwkv_model = RWKV()
reward_model = RewardModel(
rwkv_model,
num_binned_output = 5 # number of rating bins: with 5 bins, ratings fall into the 5 levels [0, 1, 2, 3, 4]
)
# mock data
seq = torch.randint(0, 20000, (1, 100))
# prompt_mask = torch.zeros(1, 100).bool() # which part of the sequence is prompt, which part is response
prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
labels = torch.randint(0, 5, (1,))
# train
loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()
# after much training
reward = reward_model(seq, prompt_mask = prompt_mask)
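# note (illustrative): with num_binned_output = 5 and no labels, the forward
# pass returns logits of shape (1, 5); passing sample = True instead draws a
# discrete rating from those logits via gumbel sampling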