Commit 51560034 authored by u010280923

opt reward model

Parent 46f70cba
......@@ -30,7 +30,7 @@ os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0' # if '1' then compile CUDA kernel for seq mode (much faster)
# from rwkv.model import RWKV # pip install rwkv
from src.rlhf.rwkv import RWKV
from src.rlhf.rwkv.model import RWKV
# model = RWKV(model='./model/rwkv-190.pth', strategy='cpu fp32')
model = RWKV(model='./model/RWKV-4-Pile-169M-20220807-8023.pth', strategy='cpu fp32')
......@@ -46,7 +46,7 @@ model = RWKV(model='./model/RWKV-4-Pile-169M-20220807-8023.pth', strategy='cpu f
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230213-8019', strategy='cuda fp16 *0+ -> cpu fp32 *1')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096', strategy='cuda:0 fp16 *25 -> cuda:1 fp16')
out, state = model.forward([187, 510, 1563, 310, 247], None)
out, state, token_embed = model.forward([187, 510, 1563, 310, 247], None)
print(out.detach().cpu().numpy()) # get logits
# out, state = model.forward([187, 510], None)
# out, state = model.forward([1563], state) # RNN has state (use deepcopy to clone states)
......@@ -58,7 +58,7 @@ ipdb.set_trace()
# print('\n')
# from src.utils import PIPELINE, PIPELINE_ARGS
# from src.rlhf.rwkv.utils import PIPELINE, PIPELINE_ARGS
# pipeline = PIPELINE(model, "20B_tokenizer.json")
# ctx = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
......
......@@ -429,7 +429,7 @@ class RWKV(pl.LightningModule):
return cfg.get("offload_optimizer") or cfg.get("offload_param")
return False
def forward(self, idx):
def forward(self, idx, extra_embed=None):
args = self.args
B, T = idx.size()
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
......@@ -437,6 +437,10 @@ class RWKV(pl.LightningModule):
x = self.emb(idx)
x_emb = x
# add an extra embedding to x, e.g. to distinguish prompt from response when training the RM
if extra_embed is not None:
x_emb = x_emb + extra_embed
if args.tiny_att_dim > 0:
for block in self.blocks:
if args.grad_cp == 1:
......
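A minimal shape check (a sketch with assumed toy sizes, not part of the commit): a (1, 1, n_embd) marker embedding broadcasts cleanly onto the (B, T, n_embd) output of self.emb(idx), which is why the addition above needs no reshaping.
import torch
B, T, n_embd = 2, 8, 512                      # assumed toy sizes
x_emb = torch.randn(B, T, n_embd)             # stands in for self.emb(idx)
extra_embed = torch.zeros(1, 1, n_embd)       # e.g. a learned prompt/response marker
x_emb = x_emb + extra_embed                   # broadcasts over batch and time; shape stays (B, T, n_embd)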
......@@ -13,7 +13,7 @@ from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from src.rlhf.utils import masked_mean, gumbel_sample
from src.model import RWKV
from src.rlhf.rwkv.model import RWKV
# helper functions
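masked_mean is imported from src.rlhf.utils but its body is not shown in this diff; a typical implementation (a sketch under that assumption, not necessarily the project's exact code) averages only the unmasked positions:
import torch
def masked_mean_sketch(t, mask=None, dim=1):
    # t: (B, T, D) token embeddings; mask: (B, T) bool, True for positions to keep
    if mask is None:
        return t.mean(dim=dim)
    mask = mask.unsqueeze(-1)                    # (B, T, 1), broadcasts over the feature dim
    total = (t * mask).sum(dim=dim)              # sum of kept embeddings
    count = mask.sum(dim=dim).clamp(min=1)       # number of kept positions, avoid division by zero
    return total / count                         # (B, D)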
......@@ -26,35 +26,25 @@ def exists(val):
class RewardModel(nn.Module):
def __init__(
self,
rwkv: RWKV,
dropout = 0.1,
num_binned_output = 0.
rwkv: RWKV
):
super().__init__()
# initialize the reward model from the pretrained model
self.rwkv = copy.deepcopy(rwkv)
self.rwkv = rwkv
# dimension of the output token embeddings
dim = rwkv.dim # todo(luxin)
dim = rwkv.args.n_embd
# number of rating bins; if 5, the ratings are [0, 1, 2, 3, 4], five levels in total
self.binned_output = num_binned_output > 1
# todo(luxin): prompt_embed and response_embed are both initialized to all zeros? Shouldn't they be distinguishable?
# used to distinguish prompt from response in the input; trained as model parameters, initialized to all zeros
self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
# self.response_embed = nn.Parameter(torch.ones(1, 1, dim))
if self.binned_output:
# if the number of rating bins is greater than 1, treat this as a multi-class classification problem
self.to_pred = nn.Linear(dim, num_binned_output)
else:
# otherwise, predict a single scalar score (regression)
self.to_pred = nn.Sequential(
nn.Linear(dim, 1, bias = False),
Rearrange('... 1 -> ...') # squeeze the trailing dimension
)
# reward score head
self.pred_reward = nn.Sequential(
nn.Linear(dim, 1),
Rearrange('... 1 -> ...') # squeeze the trailing dimension
)
def load(self, path):
path = Path(path)
......@@ -72,13 +62,10 @@ class RewardModel(nn.Module):
x,
mask = None,
prompt_mask = None,
prompt_lengths = None,
labels = None,
sample = False,
sample_temperature = 1.
prompt_lengths = None
):
# only one of prompt_mask and prompt_lengths can be given
# prompt_mask and prompt_lengths are mutually exclusive (provide exactly one)
assert not (exists(prompt_mask) and exists(prompt_lengths))
# derive prompt mask from prompt lengths
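The derivation elided by the diff typically compares a position index against each prompt length; the resulting mask then selects between the two marker embeddings. A hedged sketch with assumed batch shapes:
import torch
prompt_lengths = torch.tensor([50, 32])                       # assumed per-sample prompt lengths
seq_len = 100
positions = torch.arange(seq_len)                             # (seq_len,)
prompt_mask = positions[None, :] < prompt_lengths[:, None]    # (B, seq_len), True on prompt tokens
# pick the prompt marker where the mask is True, the response marker elsewhere:
# extra_embed = torch.where(prompt_mask[..., None], self.prompt_embed, self.response_embed)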
......@@ -98,26 +85,19 @@ class RewardModel(nn.Module):
self.response_embed
)
# todo(luxin) get embeddings from rwkv
embeds = self.rwkv(
# get the embedding of the last token
last_token_embeds = self.rwkv(
x,
extra_embed = extra_embed,
return_only_embedding = True
state=None,
extra_embed=extra_embed
)
# average all token embeddings and feed the result to the scoring head
pooled = masked_mean(embeds, mask, dim = 1)
pred = self.to_pred(pooled)
if sample and self.binned_output:
assert not exists(labels)
pred = gumbel_sample(pred, temperature = sample_temperature, dim = -1)
if not exists(labels):
return pred
# todo(luxin) the author does not use the pairwise (two-sample) loss from the paper, only a per-sample loss
if not self.binned_output:
return F.mse_loss(pred, labels)
return F.cross_entropy(pred, labels)
try:
pooled = masked_mean(last_token_embeds, mask, dim = 1)
except:
import ipdb
ipdb.set_trace()
reward = self.pred_reward(pooled)
return reward
......@@ -4,6 +4,7 @@
import types, gc, os, time
import torch
import copy
from torch.nn import functional as F
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
......@@ -364,12 +365,13 @@ class RWKV(MyModule):
########################################################################################################
def forward(self, tokens, state, full_output=False):
def forward(self, tokens, state, full_output=False, extra_embed=None):
with torch.no_grad():
w = self.w
args = self.args
if state == None:
# initialize the state
state = [None] * args.n_layer * 5
for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
dd = self.strategy[i]
......@@ -382,10 +384,15 @@ class RWKV(MyModule):
state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev)
seq_mode = len(tokens) > 1
import ipdb
ipdb.set_trace()
# input: look up each token's embedding by index
x = w['emb.weight'][tokens if seq_mode else tokens[0]]
if extra_embed is not None:
x = x + extra_embed
# inference: N Block layers (Attention + Feed Forward)
for i in range(args.n_layer):
bbb = f'blocks.{i}.'
att = f'blocks.{i}.att.'
......@@ -404,8 +411,10 @@ class RWKV(MyModule):
ATT = self.att_one
FFN = self.ffn_one
# convert tensor dtype and/or device
x = x.to(dtype=atype, device=dev)
# Attention layer
kw = self.get_w(f'{att}key.weight', atype)
vw = self.get_w(f'{att}value.weight', atype)
rw = self.get_w(f'{att}receptance.weight', atype)
......@@ -424,6 +433,7 @@ class RWKV(MyModule):
if wtype == torch.uint8 or dd.stream:
del kw, vw, rw, ow
# Feed-Forward layer
kw = self.get_w(f'{ffn}key.weight', atype)
vw = self.get_w(f'{ffn}value.weight', atype)
rw = self.get_w(f'{ffn}receptance.weight', atype)
......@@ -443,15 +453,18 @@ class RWKV(MyModule):
if (i+1) % self.RESCALE_LAYER == 0:
x = x / 2
# keep the embeddings of all tokens, or only the last token's
dd = self.strategy[args.n_layer]
x = x[-1,:] if (seq_mode and (not full_output)) else x
x = x.to(dtype=dd.atype, device=dd.device)
# apply LayerNorm to the token embedding(s); shape unchanged
x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias'])
token_embed = copy.deepcopy(x)
if w['head.weight'].dtype != torch.uint8:
x = x @ w['head.weight']
else:
x = x @ self.get_w('head.weight', dd.atype)
return x.float(), state
return x.float(), state, token_embed.float()
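With this change the inference forward returns a third value: the post-LayerNorm hidden state used as the token embedding. A hedged usage sketch (paths as in the chat script at the top of this commit; note the ipdb breakpoint added above would pause execution here):
from src.rlhf.rwkv.model import RWKV
model = RWKV(model='./model/RWKV-4-Pile-169M-20220807-8023.pth', strategy='cpu fp32')
out, state, token_embed = model.forward([187, 510, 1563, 310, 247], None)
print(out.shape)          # logits over the vocabulary for the last token
print(token_embed.shape)  # hidden state of the last token, roughly (n_embd,)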
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import json, time, random, os
import numpy as np
import torch
from torch.nn import functional as F
from tokenizers import Tokenizer
class PIPELINE_ARGS():
def __init__(self, temperature=1.0, top_p=0.85, top_k=0, alpha_frequency=0.2, alpha_presence=0.2, token_ban=[], token_stop=[]):
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
self.token_ban = token_ban # ban the generation of some tokens
self.token_stop = token_stop # stop generation whenever you see any token here
class PIPELINE():
def __init__(self, model, WORD_NAME):
self.model = model
self.tokenizer = Tokenizer.from_file(WORD_NAME)
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
context[c] = context[c].strip().strip('\u3000').strip('\r')
context = list(filter(lambda c: c != '', context))
context = '\n' + ('\n'.join(context)).strip()
if context == '':
context = '\n'
return context
def encode(self, x):
return self.tokenizer.encode(x).ids
def decode(self, x):
return self.tokenizer.decode(x)
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
probs = F.softmax(logits.float(), dim=-1)
top_k = int(top_k)
if probs.device == torch.device('cpu'):
probs = probs.numpy()
sorted_ids = np.argsort(probs)
sorted_probs = probs[sorted_ids][::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if top_k < len(probs) and top_k > 0:
probs[sorted_ids[:-top_k]] = 0
if temperature != 1.0:
probs = probs ** (1.0 / temperature)
probs = probs / np.sum(probs)
out = np.random.choice(a=len(probs), p=probs)
return int(out)
else:
sorted_ids = torch.argsort(probs)
sorted_probs = probs[sorted_ids]
sorted_probs = torch.flip(sorted_probs, dims=(0,))
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if top_k < len(probs) and top_k > 0:
probs[sorted_ids[:-top_k]] = 0
if temperature != 1.0:
probs = probs ** (1.0 / temperature)
out = torch.multinomial(probs, num_samples=1)[0]
return int(out)
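A small numeric sketch of the nucleus (top-p) cutoff computed above: sort the probabilities in descending order, take the cumulative sum, and zero out every probability below the value at which the cumulative mass first exceeds top_p.
import numpy as np
probs = np.array([0.05, 0.40, 0.10, 0.30, 0.15])
top_p = 0.85
sorted_probs = np.sort(probs)[::-1]                    # [0.40, 0.30, 0.15, 0.10, 0.05]
cumulative = np.cumsum(sorted_probs)                   # [0.40, 0.70, 0.85, 0.95, 1.00]
cutoff = sorted_probs[np.argmax(cumulative > top_p)]   # first index where mass > 0.85 -> 0.10
probs[probs < cutoff] = 0                              # only the 0.05 entry is dropped here
probs = probs / probs.sum()                            # renormalize before np.random.choice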
def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
all_tokens = []
out_last = 0
out_str = ''
occurrence = {}
for i in range(token_count):
# forward & adjust prob.
out, state = self.model.forward(self.encode(ctx) if i == 0 else [token], state)
for n in args.token_ban:
out[n] = -float('inf')
for n in occurrence:
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
# sampler
token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
if token in args.token_stop:
break
all_tokens += [token]
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
# output
tmp = self.decode(all_tokens[out_last:])
if '\ufffd' not in tmp: # is valid utf-8 string?
if callback:
callback(tmp)
out_str += tmp
out_last = i + 1
return out_str
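A hedged usage sketch of the pipeline above (tokenizer file and model path assumed from the commented-out example near the top of this commit; generate unpacks two values from model.forward, so it expects the original (logits, state) return rather than the three-value return introduced in this commit):
from src.rlhf.rwkv.model import RWKV
from src.rlhf.rwkv.utils import PIPELINE, PIPELINE_ARGS
model = RWKV(model='./model/RWKV-4-Pile-169M-20220807-8023.pth', strategy='cpu fp32')
pipeline = PIPELINE(model, "20B_tokenizer.json")
args = PIPELINE_ARGS(temperature=1.0, top_p=0.85)
out_str = pipeline.generate("\nIn a shocking finding,", token_count=50, args=args)
print(out_str)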
......@@ -6,27 +6,259 @@
'''
# here put the import lib
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
if __name__ == "__main__":
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
rank_zero_info("########## work in progress ##########")
########################################################################################################
#
# example: train a simple L12-D768 RWKV on dummy data
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "" --data_type "dummy" --vocab_size 0 \
# --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
# --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
# --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: train a simple L6-D512 RWKV from scratch on enwik8
#
# python train.py --load_model "" --wandb "" --proj_dir "out" \
# --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
# example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
#
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
# --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
parser = ArgumentParser()
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--random_seed", default="-1", type=int)
parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int)
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
parser.add_argument("--my_img_version", default=0, type=str)
parser.add_argument("--my_img_size", default=0, type=int)
parser.add_argument("--my_img_bit", default=0, type=int)
parser.add_argument("--my_img_clip", default='x', type=str)
parser.add_argument("--my_img_clip_scale", default=1, type=float)
parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str)
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int)
parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
########################################################################################################
import os, warnings, math, datetime, sys, time
import numpy as np
import torch
from torch.utils.data import DataLoader
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
seed_everything(args.random_seed)
np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
# os.environ["WDS_SHOW_SEED"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
args.enable_checkpointing = False
args.replace_sampler_ddp = False
args.logger = False
args.gradient_clip_val = 1.0
args.num_sanity_val_steps = 0
args.check_val_every_n_epoch = int(1e20)
args.log_every_n_steps = int(1e20)
args.max_epochs = -1 # continue forever
args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
os.environ["RWKV_MY_TESTING"] = args.my_testing
if args.dim_att <= 0:
args.dim_att = args.n_embd
if args.dim_ffn <= 0:
args.dim_ffn = args.n_embd * 4
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
if not os.path.exists(args.proj_dir):
os.makedirs(args.proj_dir)
samples_per_epoch = args.epoch_steps * args.real_bsz
tokens_per_epoch = samples_per_epoch * args.ctx_len
rank_zero_info(
f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
)
rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT_ON"] = "0"
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
if args.precision == "fp32":
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
else:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if "32" in args.precision:
args.precision = 32
elif args.precision == "fp16":
args.precision = 16
else:
args.precision = "bf16"
########################################################################################################
import torch
from src.rlhf.reward import RewardModel
from src.model import RWKV
from src.rlhf.rwkv.model import RWKV
rwkv_model = RWKV()
model = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
strategy = "cpu fp32"
rwkv_model = RWKV(model, strategy)
dim = rwkv_model.args.n_embd
reward_model = RewardModel(
rwkv_model,
num_binned_output = 5 # number of rating bins; if 5, the ratings are [0, 1, 2, 3, 4], five levels in total
rwkv_model
)
# mock data
seq = torch.randint(0, 20000, (1, 100))
# prompt_mask = torch.zeros(1, 100).bool() # which part of the sequence is prompt, which part is response
prompt = torch.randint(0, dim, (1, 50))
prefer_response = torch.randint(0, dim, (1, 50))
alter_response = torch.randint(0, dim, (1, 50))
prefer_pair = torch.concat((prompt, prefer_response), dim=1)
alter_pair = torch.concat((prompt, alter_response), dim=1)
# which part of the sequence is prompt, which part is response
prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
labels = torch.randint(0, 5, (1,))
# labels = torch.randint(0, 5, (1,))
# train
loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()
# loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
# loss.backward()
# inference
prefer_reward = reward_model(prefer_pair, prompt_mask = prompt_mask)
alter_reward = reward_model(alter_pair, prompt_mask = prompt_mask)
# after much training
reward = reward_model(seq, prompt_mask = prompt_mask)
\ No newline at end of file
print("Preferred response reward:", prefer_reward)
print("Alternate response reward:", alter_reward)
\ No newline at end of file
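The todo in reward.py notes that the pairwise loss from the RLHF papers is not used yet. Given the prefer_reward / alter_reward pair computed above, that loss would look roughly like the following sketch (not part of the commit; it also assumes the rewards carry gradients, whereas the inference-mode RWKV above runs under torch.no_grad):
import torch.nn.functional as F
# Pairwise preference loss: push the preferred response's reward above the alternative's,
# i.e. loss = -log(sigmoid(r_preferred - r_alternative))
pairwise_loss = -F.logsigmoid(prefer_reward - alter_reward).mean()
# pairwise_loss.backward()   # only meaningful if the reward model was run with gradients enabled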