Commit 3c80d013 authored by u010280923

opt reward model

Parent e45f1cdf
......@@ -63,6 +63,16 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft"
### Reward Model
```
python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
--my_qa_mask 1
```
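The file passed via `--data_file` is read by `RMDataset` with `pd.read_csv` and must provide `prompt`, `preferred` and `alternate` columns. A minimal sketch for generating a compatible `data/rm_mock_data.csv` (the example row is made up, illustrative only):
```
import pandas as pd

# One illustrative comparison pair; real data would contain many such rows.
pd.DataFrame([{
    "prompt": "What is the capital of France?",
    "preferred": "The capital of France is Paris.",
    "alternate": "France is a country in Europe.",
}]).to_csv("data/rm_mock_data.csv", index=False)
```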
### Integrating RLHF (Reinforcement Learning with Human Feedback)
......
......@@ -257,3 +257,45 @@ class S2SDataset(Dataset):
z = torch.tensor(z, dtype=torch.long)
return x, y, z
class RMDataset(Dataset):
def __init__(self, args):
self.args = args
self.vocab_size = args.vocab_size
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
self.tokenizer = TOKENIZER(WORD_NAME)
pf = pd.read_csv(args.data_file)
data_list = []
for index, row in pf.iterrows():
prompt = row["prompt"]
preferred = row["preferred"]
alternate = row["alternate"]
preferred_sample = f"{prompt}\n{preferred}"
alternate_sample = f"{prompt}\n{alternate}"
data_list.append((self.tokenizer.tokenizer.encode(preferred_sample),
self.tokenizer.tokenizer.encode(alternate_sample)))
self.data = data_list
def __len__(self):
return len(self.data)
def __getitem__(self, index):
ctx_len = self.args.ctx_len
# no shifted LM target here, so cap the sample at ctx_len (model.forward asserts T <= ctx_len)
req_len = ctx_len
preferred_sample, alternate_sample = self.data[index]
preferred_sample = preferred_sample[: req_len]
alternate_sample = alternate_sample[: req_len]
preferred_sample = preferred_sample + [0] * (req_len - len(preferred_sample))
alternate_sample = alternate_sample + [0] * (req_len - len(alternate_sample))
x_p = torch.tensor(preferred_sample, dtype=torch.long)
x_a = torch.tensor(alternate_sample, dtype=torch.long)
return x_p, x_a
\ No newline at end of file
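`RewardModel.training_step` in this commit unpacks four tensors per batch (token ids plus a prompt mask for each of the preferred and alternate sequences), while `RMDataset.__getitem__` returns only the two id tensors. A hedged sketch of how the dataset could also emit boolean prompt masks — the `pad_and_mask` helper and the per-sample `prompt_len` are hypothetical, not part of this commit:
```
import torch

def pad_and_mask(sample_ids, prompt_len, req_len):
    # Truncate to req_len, pad with token id 0, and mark prompt positions as True.
    ids = sample_ids[:req_len]
    n_prompt = min(prompt_len, len(ids))
    pad = req_len - len(ids)
    mask = [True] * n_prompt + [False] * (len(ids) - n_prompt + pad)
    ids = ids + [0] * pad
    return torch.tensor(ids, dtype=torch.long), torch.tensor(mask, dtype=torch.bool)

# Inside __getitem__ (prompt_len would need to be stored per sample in __init__):
#   x_p, m_p = pad_and_mask(preferred_sample, prompt_len, req_len)
#   x_a, m_a = pad_and_mask(alternate_sample, prompt_len, req_len)
#   return x_p, x_a, m_p, m_a
```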
......@@ -429,7 +429,7 @@ class RWKV(pl.LightningModule):
return cfg.get("offload_optimizer") or cfg.get("offload_param")
return False
def forward(self, idx, extra_embed=None):
def forward(self, idx, extra_embed=None, rm_train=False):
args = self.args
B, T = idx.size()
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
......@@ -456,6 +456,10 @@ class RWKV(pl.LightningModule):
x = self.ln_out(x)
# For RM training, return the token encodings directly
if rm_train is True:
return x
if args.head_qk > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
......
......@@ -8,33 +8,63 @@ from beartype.typing import Tuple, Optional
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from src.rlhf.utils import masked_mean, gumbel_sample
from src.rlhf.rwkv.model import RWKV
# from src.model import RWKV
from src.model import RWKV
# helper functions
def exists(val):
return val is not None
# loss function: pairwise ranking loss as in InstructGPT -- the preferred
# response should receive a higher reward than the alternate one
def loss_function(prefer_reward, alter_reward):
return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
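# Quick numerical check (illustrative only): with prefer_reward = 2.0 and
# alter_reward = 0.0, sigmoid(2.0) ≈ 0.881, so the loss is -log(0.881) ≈ 0.13;
# with the rewards swapped it rises to ≈ 2.13, so minimizing this loss pushes
# the preferred response's reward above the alternate's.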
# Reward Model - RWKV with a scalar head
@beartype
class RewardModel(nn.Module):
class RewardModel(pl.LightningModule):
def __init__(
self,
args,
rwkv: RWKV
):
super().__init__()
# Load the RWKV model
rwkv = RWKV(args)
if len(args.load_model) == 0:
rank_zero_info(f"SFT must load model, please input ")
exit(1)
rank_zero_info(f"########## Loading {args.load_model}... ##########")
try:
load_dict = torch.load(args.load_model, map_location="cpu")
except:
rank_zero_info(f"Bad checkpoint {args.load_model}")
exit(1)
if args.load_partial == 1:
load_keys = load_dict.keys()
for k in rwkv.state_dict():
if k not in load_keys:
load_dict[k] = rwkv.state_dict()[k]
rwkv.load_state_dict(load_dict)
# Initialize the reward model from the pretrained model
self.rwkv = rwkv
self.args = args
# Dimensionality of the output token embeddings
dim = rwkv.args.n_embd
dim = self.args.n_embd
# Learned embeddings used to distinguish prompt from response tokens in the input; trained as model parameters and initialized to all zeros
self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
......@@ -56,8 +86,18 @@ class RewardModel(nn.Module):
*self.to_pred.parameters(),
*self.rwkv.parameters()
]
def forward(
def configure_optimizers(self):
# Settings from the InstructGPT paper: lr=1e-5, betas=(0.9, 0.95)
# Note: self.parameters() already includes the wrapped RWKV parameters, and
# torch.optim param groups must use the key "params", so a single group suffices
optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr_init, betas=self.args.betas)
return optimizer
def single_forward(
self,
x,
mask = None,
......@@ -89,15 +129,29 @@ class RewardModel(nn.Module):
last_token_embeds = self.rwkv(
x,
state=None,
extra_embed=extra_embed
extra_embed=extra_embed,
rm_train=True
)
# Average the token embeddings under the mask and feed the pooled vector to the scoring head
pooled = masked_mean(last_token_embeds, mask, dim = 1)
reward = self.pred_reward(pooled)
return reward
def forward(self, prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask):
prefer_reward = self.single_forward(prefer_x, prefer_x_prompt_mask)
alter_reward = self.single_forward(alter_x, alter_x_prompt_mask)
return prefer_reward, alter_reward
def training_step(self, batch, batch_idx):
prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask = batch
prefer_reward, alter_reward = self(
prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask)
loss = loss_function(prefer_reward, alter_reward)
return loss
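# Assumption about the data pipeline: prefer_x_prompt_mask / alter_x_prompt_mask
# are expected to be boolean tensors with the same shape as the token ids, True
# over prompt positions and False over response positions, e.g.
# torch.cat((torch.ones(1, p_len).bool(), torch.zeros(1, r_len).bool()), dim=1)
# as in the mock-data demo below; the RMDataset added in this commit does not
# build these masks yet.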
......@@ -365,7 +365,7 @@ class RWKV(MyModule):
########################################################################################################
def forward(self, tokens, state, full_output=False, extra_embed=None):
def forward(self, tokens, state, full_output=False, extra_embed=None, rm_train=False):
with torch.no_grad():
w = self.w
args = self.args
......@@ -384,8 +384,6 @@ class RWKV(MyModule):
state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev)
seq_mode = len(tokens) > 1
import ipdb
ipdb.set_trace()
# Input: look up each token's embedding from the token ids
x = w['emb.weight'][tokens if seq_mode else tokens[0]]
......@@ -460,11 +458,12 @@ class RWKV(MyModule):
# Apply LayerNorm to the token embeddings (dimensionality unchanged)
x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias'])
token_embed = copy.deepcopy(x)
if rm_train is True:
return x
if w['head.weight'].dtype != torch.uint8:
x = x @ w['head.weight']
else:
x = x @ self.get_w('head.weight', dd.atype)
return x.float(), state, token_embed.float()
return x.float(), state
......@@ -220,45 +220,44 @@ if __name__ == "__main__":
########################################################################################################
# Train the RM model
def loss_function(prefer_reward, alter_reward):
return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
import torch
from tqdm import tqdm
from src.trainer import train_callback
from src.rlhf.reward import RewardModel
from src.model import RWKV
from src.dataset import RMDataset
# Load the training data
train_data = RMDataset(args)
args.vocab_size = train_data.vocab_size
# RM model
rm_model = RewardModel(args)
import torch
from src.rlhf.reward import RewardModel
from src.rlhf.rwkv.model import RWKV
model = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
strategy = "cpu fp32"
rwkv_model = RWKV(model, strategy)
dim = rwkv_model.args.n_embd
reward_model = RewardModel(
rwkv_model
)
# mock data
prompt = torch.randint(0, dim, (1, 50))
prefer_response = torch.randint(0, dim, (1, 50))
alter_response = torch.randint(0, dim, (1, 50))
# Training
trainer = Trainer.from_argparse_args(args, callbacks=[train_callback(args)])
prefer_pair = torch.concat((prompt, prefer_response), dim=1)
alter_pair = torch.concat((prompt, alter_response), dim=1)
if trainer.global_rank == 0:
for n in rm_model.state_dict():
shape = rm_model.state_dict()[n].shape
shape = [i for i in shape if i != 1]
if len(shape) > 1:
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
else:
print(f"{str(shape[0]).ljust(5)} {n}")
# which part of the sequence is prompt, which part is response
prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
# labels = torch.randint(0, 5, (1,))
if "deepspeed" in args.strategy:
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
# train
# loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
# loss.backward()
# must set shuffle=True, persistent_workers=False (because worker is in another thread)
data_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
# inference
prefer_reward = reward_model(prefer_pair, prompt_mask = prompt_mask)
alter_reward = reward_model(alter_pair, prompt_mask = prompt_mask)
trainer.fit(rm_model, data_loader)
print("Preferred response reward:", prefer_reward)
print("Alternate response reward:", alter_reward)
\ No newline at end of file
'''
@File : train_rm_demo.py
@Time : 2023/03/10 00:54:57
@Author : Lu Xin
@Contact : luxin@csdn.net
'''
# here put the import lib
import torch
from tqdm import tqdm
from src.rlhf.reward import RewardModel
from src.rlhf.rwkv.model import RWKV
def loss_function(prefer_reward, alter_reward):
return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
model = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
strategy = "cpu fp32"
rwkv_model = RWKV(model, strategy)
reward_model = RewardModel(
rwkv_model
)
# as used in the InstructGPT paper
optimizer = torch.optim.Adam(
reward_model.parameters(), lr=1e-5, betas=(0.9, 0.95))
# Mock data
dim = 20000
prompt = torch.randint(0, dim, (1, 50))
prefer_response = torch.randint(0, dim, (1, 50))
alter_response = torch.randint(0, dim, (1, 50))
prefer_pair = torch.concat((prompt, prefer_response), dim=1)
alter_pair = torch.concat((prompt, alter_response), dim=1)
prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
for epoch in range(100):
# Compute the rewards
prefer_reward = reward_model(prefer_pair, prompt_mask = prompt_mask)
alter_reward = reward_model(alter_pair, prompt_mask = prompt_mask)
# print(f"prefer_reward: {prefer_reward}")
# print(f"alter_reward: {alter_reward}")
# train
loss = loss_function(prefer_reward, alter_reward)
print(f"loss: {loss}")
# Backward pass
loss.backward()
optimizer.step()
# Zero the gradients
optimizer.zero_grad()