Commit 3c80d013 authored by u010280923

opt reward model

Parent e45f1cdf
@@ -63,6 +63,16 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft"
### Reward Model
```
python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
--my_qa_mask 1
```
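The file passed via `--data_file` is read by `RMDataset` (see `src/dataset.py` below) as a CSV with `prompt`, `preferred`, and `alternate` columns. A minimal sketch of how a mock `data/rm_mock_data.csv` could be generated — the rows here are invented placeholders, not part of this commit:

```python
# Sketch: build a toy preference CSV in the layout RMDataset expects.
import os
import pandas as pd

rows = [
    {
        "prompt": "Explain what a reward model is.",
        "preferred": "A reward model scores how well a response answers the prompt.",
        "alternate": "It is a model.",
    },
    {
        "prompt": "Say hello politely.",
        "preferred": "Hello, nice to meet you!",
        "alternate": "hi",
    },
]

os.makedirs("data", exist_ok=True)
pd.DataFrame(rows).to_csv("data/rm_mock_data.csv", index=False)
```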
### Integrating RLHF (Reinforcement Learning from Human Feedback)
......
@@ -257,3 +257,45 @@ class S2SDataset(Dataset):
        z = torch.tensor(z, dtype=torch.long)
        return x, y, z
class RMDataset(Dataset):
    def __init__(self, args):
        self.args = args
        self.vocab_size = args.vocab_size

        WORD_NAME = [
            "20B_tokenizer.json",
            "20B_tokenizer.json",
        ]  # [vocab, vocab] for Pile model
        self.tokenizer = TOKENIZER(WORD_NAME)

        # Each row of the CSV holds a prompt together with a preferred and an alternate response
        pf = pd.read_csv(args.data_file)
        data_list = []
        for index, row in pf.iterrows():
            prompt = row["prompt"]
            preferred = row["preferred"]
            alternate = row["alternate"]
            preferred_sample = f"{prompt}\n{preferred}"
            alternate_sample = f"{prompt}\n{alternate}"
            data_list.append((self.tokenizer.tokenizer.encode(preferred_sample),
                              self.tokenizer.tokenizer.encode(alternate_sample)))
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        ctx_len = self.args.ctx_len
        req_len = ctx_len + 1

        preferred_sample, alternate_sample = self.data[index]

        # Truncate to the context window, then right-pad with token id 0
        preferred_sample = preferred_sample[:req_len]
        alternate_sample = alternate_sample[:req_len]
        preferred_sample = preferred_sample + [0] * (req_len - len(preferred_sample))
        alternate_sample = alternate_sample + [0] * (req_len - len(alternate_sample))

        x_p = torch.tensor(preferred_sample, dtype=torch.long)
        x_a = torch.tensor(alternate_sample, dtype=torch.long)

        return x_p, x_a
\ No newline at end of file
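A short usage sketch (not part of the commit), assuming the tokenizer file and the mock CSV above are present and using a minimal `args` namespace with only the fields `RMDataset` touches; each item is a pair of zero-padded token-id tensors of length `ctx_len + 1`:

```python
# Sketch: iterate RMDataset with a DataLoader and inspect one batch.
from types import SimpleNamespace
from torch.utils.data import DataLoader

from src.dataset import RMDataset

args = SimpleNamespace(data_file="data/rm_mock_data.csv", vocab_size=50277, ctx_len=2048)
train_data = RMDataset(args)

loader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=True)
x_p, x_a = next(iter(loader))
print(x_p.shape, x_a.shape)  # both torch.Size([2, args.ctx_len + 1])
```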
@@ -429,7 +429,7 @@ class RWKV(pl.LightningModule):
            return cfg.get("offload_optimizer") or cfg.get("offload_param")
        return False

    def forward(self, idx, extra_embed=None, rm_train=False):
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

@@ -456,6 +456,10 @@ class RWKV(pl.LightningModule):
        x = self.ln_out(x)
        # For reward-model training, return the token embeddings directly
        if rm_train:
            return x
        if args.head_qk > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
......
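The new `rm_train` flag short-circuits the language-model head. A sketch of the two call modes, assuming an initialized `src.model.RWKV` instance `rwkv_model` and a batch of token ids `idx` of shape `(B, T)`:

```python
# With rm_train=True the forward pass stops after ln_out and returns the
# per-token hidden states for the reward model to pool; without it, the
# usual vocabulary logits are returned.
token_embeds = rwkv_model(idx, rm_train=True)   # shape (B, T, n_embd)
logits = rwkv_model(idx)                        # shape (B, T, vocab_size)
```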
@@ -8,33 +8,63 @@ from beartype.typing import Tuple, Optional

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from src.rlhf.utils import masked_mean, gumbel_sample
from src.model import RWKV
# helper functions

def exists(val):
    return val is not None
# loss function: pairwise preference loss (the preferred response should receive the higher reward)
def loss_function(prefer_reward, alter_reward):
    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
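This is the pairwise preference loss from the InstructGPT paper: for a prompt $x$ with preferred response $y_w$ and alternate response $y_l$,

$$
\mathcal{L}(\theta) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\big[\log \sigma\big(r_\theta(x, y_w) - r_\theta(x, y_l)\big)\big]
$$

so minimizing it pushes the reward of the preferred response above that of the alternate.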
# Reward Model - RWKV with a scalar head
@beartype
class RewardModel(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        # Load the pretrained RWKV (SFT) model
        rwkv = RWKV(args)

        if len(args.load_model) == 0:
            rank_zero_info("The reward model must be initialized from a trained checkpoint; please set --load_model")
            exit(1)

        rank_zero_info(f"########## Loading {args.load_model}... ##########")
        try:
            load_dict = torch.load(args.load_model, map_location="cpu")
        except Exception:
            rank_zero_info(f"Bad checkpoint {args.load_model}")
            exit(1)

        if args.load_partial == 1:
            load_keys = load_dict.keys()
            for k in rwkv.state_dict():
                if k not in load_keys:
                    load_dict[k] = rwkv.state_dict()[k]
        rwkv.load_state_dict(load_dict)
        # Initialize the reward model from the pretrained model
        self.rwkv = rwkv
        self.args = args

        # dimensionality of the output token embeddings
        dim = self.args.n_embd

        # Learnable embeddings that mark which input tokens belong to the prompt
        # and which to the response; trained with the model, initialized to zero
        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
@@ -56,8 +86,18 @@ class RewardModel(nn.Module):
            *self.to_pred.parameters(),
            *self.rwkv.parameters()
        ]
    def configure_optimizers(self):
        # Settings from the InstructGPT paper: lr=1e-5, betas=(0.9, 0.95).
        # self.parameters() already includes the RWKV backbone's parameters,
        # so a single parameter group is sufficient.
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.args.lr_init, betas=self.args.betas)

        return optimizer
    def single_forward(
        self,
        x,
        mask = None,
@@ -89,15 +129,29 @@ class RewardModel(nn.Module):
        # src.model.RWKV.forward takes no `state` argument, so only the extra
        # embedding and the rm_train flag are passed here
        last_token_embeds = self.rwkv(
            x,
            extra_embed=extra_embed,
            rm_train=True
        )
        # Average all token embeddings and feed the pooled vector to the scoring head
        pooled = masked_mean(last_token_embeds, mask, dim=1)
        reward = self.pred_reward(pooled)

        return reward
    def forward(self, prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask):
        # Score the preferred and the alternate sequence with the same reward head
        prefer_reward = self.single_forward(prefer_x, prefer_x_prompt_mask)
        alter_reward = self.single_forward(alter_x, alter_x_prompt_mask)

        return prefer_reward, alter_reward
    def training_step(self, batch, batch_idx):
        prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask = batch
        prefer_reward, alter_reward = self(
            prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask)
        loss = loss_function(prefer_reward, alter_reward)

        return loss
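A hedged inference sketch (not part of the commit) for scoring one preference pair with a trained model, reusing `rm_model` and `train_data` as defined in `train_rm.py` further down; the all-True masks are placeholders, and the inputs are truncated because `RWKV.forward` asserts `T <= ctx_len`:

```python
# Sketch: compare the rewards of a preferred/alternate pair.
import torch

prefer_x, alter_x = train_data[0]
ctx_len = rm_model.args.ctx_len
prefer_x = prefer_x[:ctx_len].unsqueeze(0)   # add batch dimension
alter_x = alter_x[:ctx_len].unsqueeze(0)

# placeholder masks: keep every token when pooling
prefer_mask = torch.ones_like(prefer_x, dtype=torch.bool)
alter_mask = torch.ones_like(alter_x, dtype=torch.bool)

with torch.no_grad():
    prefer_reward, alter_reward = rm_model(prefer_x, alter_x, prefer_mask, alter_mask)
print(float(prefer_reward), float(alter_reward))
```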
@@ -365,7 +365,7 @@ class RWKV(MyModule):

    ########################################################################################################

    def forward(self, tokens, state, full_output=False, extra_embed=None, rm_train=False):
        with torch.no_grad():
            w = self.w
            args = self.args
@@ -384,8 +384,6 @@ class RWKV(MyModule):
                    state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev)

            seq_mode = len(tokens) > 1
            # Input: look up each token's embedding by its id
            x = w['emb.weight'][tokens if seq_mode else tokens[0]]

@@ -460,11 +458,12 @@ class RWKV(MyModule):
            # LayerNorm over the token embeddings; the shape is unchanged
            x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias'])

            # For reward-model training, return the normalized token embeddings
            if rm_train:
                return x
            if w['head.weight'].dtype != torch.uint8:
                x = x @ w['head.weight']
            else:
                x = x @ self.get_w('head.weight', dd.atype)

            return x.float(), state
@@ -220,45 +220,44 @@ if __name__ == "__main__":

    ########################################################################################################

    # Train the reward model
    def loss_function(prefer_reward, alter_reward):
        # pairwise preference loss: the preferred response should receive the higher reward
        return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))

    import torch
    from tqdm import tqdm

    from src.trainer import train_callback
    from src.rlhf.reward import RewardModel
    from src.model import RWKV
    from src.dataset import RMDataset

    # Load the training data
    train_data = RMDataset(args)
    args.vocab_size = train_data.vocab_size

    # Reward model
    rm_model = RewardModel(args)

    # Training
    trainer = Trainer.from_argparse_args(args, callbacks=[train_callback(args)])

    if trainer.global_rank == 0:
        for n in rm_model.state_dict():
            shape = rm_model.state_dict()[n].shape
            shape = [i for i in shape if i != 1]
            if len(shape) > 1:
                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
            else:
                print(f"{str(shape[0]).ljust(5)} {n}")

    if "deepspeed" in args.strategy:
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000

    # must set shuffle=True, persistent_workers=False (because the worker runs in another thread)
    data_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)

    trainer.fit(rm_model, data_loader)
\ No newline at end of file
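Note that `RMDataset.__getitem__` returns two tensors while `training_step` unpacks four per batch. One possible bridge, sketched here as an assumption rather than the repository's own solution, is a `collate_fn` that derives masks from the zero padding:

```python
# Hypothetical collate function: stack the pairs and mark non-padding positions
# (RMDataset right-pads with token id 0) with a simple boolean mask.
import torch

def rm_collate(batch):
    x_p = torch.stack([p for p, _ in batch])
    x_a = torch.stack([a for _, a in batch])
    mask_p = x_p != 0
    mask_a = x_a != 0
    return x_p, x_a, mask_p, mask_a

# data_loader = DataLoader(train_data, shuffle=True, batch_size=args.micro_bsz,
#                          collate_fn=rm_collate, drop_last=True)
```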
'''
@File : train_rm_demo.py
@Time : 2023/03/10 00:54:57
@Author : Lu Xin
@Contact : luxin@csdn.net
'''
# here put the import lib
import torch
from tqdm import tqdm
from src.rlhf.reward import RewardModel
from src.rlhf.rwkv.model import RWKV
def loss_function(prefer_reward, alter_reward):
    # pairwise preference loss: the preferred response should receive the higher reward
    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
model = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
strategy = "cpu fp32"
rwkv_model = RWKV(model, strategy)
reward_model = RewardModel(
    rwkv_model
)
# as used in the InstructGPT paper
optimizer = torch.optim.Adam(
reward_model.parameters(), lr=1e-5, betas=(0.9, 0.95))
# Mock data
dim = 20000
prompt = torch.randint(0, dim, (1, 50))
prefer_response = torch.randint(0, dim, (1, 50))
alter_response = torch.randint(0, dim, (1, 50))
prefer_pair = torch.concat((prompt, prefer_response), dim=1)
alter_pair = torch.concat((prompt, alter_response), dim=1)
prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
for epoch in range(100):
    # Compute the reward for each pair
    prefer_reward = reward_model(prefer_pair, prompt_mask=prompt_mask)
    alter_reward = reward_model(alter_pair, prompt_mask=prompt_mask)

    # print(f"prefer_reward: {prefer_reward}")
    # print(f"alter_reward: {alter_reward}")

    # Train
    loss = loss_function(prefer_reward, alter_reward)
    print(f"loss: {loss}")

    # Backward pass
    loss.backward()
    optimizer.step()

    # Zero the gradients
    optimizer.zero_grad()