From 3c80d0139664aaeb20a5acaf6b63fe477de46d9e Mon Sep 17 00:00:00 2001
From: u010280923
Date: Fri, 10 Mar 2023 11:49:47 +0800
Subject: [PATCH] opt reward model

---
 README.md              | 10 ++++++
 src/dataset.py         | 48 +++++++++++++++++++++++
 src/model.py           |  6 +++-
 src/rlhf/reward.py     | 76 ++++++++++++++++++++++++++++++++++++------
 src/rlhf/rwkv/model.py |  9 +++--
 train_rm.py            | 61 +++++++++++++++++----------------
 train_rm_demo.py       | 60 ++++++++++++++++++++++++++++++++++
 7 files changed, 222 insertions(+), 48 deletions(-)
 create mode 100644 train_rm_demo.py

diff --git a/README.md b/README.md
index e8b3018..f37df80 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,16 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft"
 
 ### Reward Model
 
+```
+python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" \
+--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
+--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
+--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
+--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
+--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
+--my_qa_mask 1
+```
+
 
 ### 接入RLHF(Reinforcement Learning with Human Feedback)
 
diff --git a/src/dataset.py b/src/dataset.py
index 8e0160b..68b6807 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -257,3 +257,51 @@ class S2SDataset(Dataset):
         z = torch.tensor(z, dtype=torch.long)
 
         return x, y, z
+
+
+class RMDataset(Dataset):
+    def __init__(self, args):
+        self.args = args
+        self.vocab_size = args.vocab_size
+        WORD_NAME = [
+            "20B_tokenizer.json",
+            "20B_tokenizer.json",
+        ]  # [vocab, vocab] for Pile model
+
+        self.tokenizer = TOKENIZER(WORD_NAME)
+        pf = pd.read_csv(args.data_file)
+        data_list = []
+        for index, row in pf.iterrows():
+            prompt = row["prompt"]
+            preferred = row["preferred"]
+            alternate = row["alternate"]
+            preferred_sample = f"{prompt}\n{preferred}"
+            alternate_sample = f"{prompt}\n{alternate}"
+            data_list.append((self.tokenizer.tokenizer.encode(preferred_sample),
+                              self.tokenizer.tokenizer.encode(alternate_sample)))
+        self.data = data_list
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        # no next-token shift is needed for reward modelling; stay within the model's ctx_len
+        req_len = self.args.ctx_len
+        preferred_sample, alternate_sample = self.data[index]
+
+        preferred_sample = preferred_sample[: req_len]
+        alternate_sample = alternate_sample[: req_len]
+
+        # mark the real tokens so padding can be excluded when pooling the reward score
+        preferred_mask = [1] * len(preferred_sample) + [0] * (req_len - len(preferred_sample))
+        alternate_mask = [1] * len(alternate_sample) + [0] * (req_len - len(alternate_sample))
+
+        preferred_sample = preferred_sample + [0] * (req_len - len(preferred_sample))
+        alternate_sample = alternate_sample + [0] * (req_len - len(alternate_sample))
+
+        x_p = torch.tensor(preferred_sample, dtype=torch.long)
+        x_a = torch.tensor(alternate_sample, dtype=torch.long)
+        m_p = torch.tensor(preferred_mask, dtype=torch.bool)
+        m_a = torch.tensor(alternate_mask, dtype=torch.bool)
+
+        return x_p, x_a, m_p, m_a
\ No newline at end of file
diff --git a/src/model.py b/src/model.py
index 8097cdf..9a785ce 100644
--- a/src/model.py
+++ b/src/model.py
@@ -429,7 +429,7 @@ class RWKV(pl.LightningModule):
             return cfg.get("offload_optimizer") or cfg.get("offload_param")
         return False
 
-    def forward(self, idx, extra_embed=None):
+    def forward(self, idx, extra_embed=None, rm_train=False):
         args = self.args
         B, T = idx.size()
         assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
@@ -456,6 +456,10 @@ class RWKV(pl.LightningModule):
 
         x = self.ln_out(x)
 
+        # for reward-model training, return the token embeddings from the final LayerNorm
+        if rm_train:
+            return x
+
         if args.head_qk > 0:
             q = self.head_q(x)[:, :T, :]
             k = self.head_k(x)[:, :T, :]
diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index 2b68cd3..1b60359 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -8,33 +8,63 @@
 from beartype.typing import Tuple, Optional
 
 import torch
 from torch import nn
 import torch.nn.functional as F
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_info
 
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 from src.rlhf.utils import masked_mean, gumbel_sample
-from src.rlhf.rwkv.model import RWKV
+# use the trainable RWKV from src.model rather than the inference-only src.rlhf.rwkv.model
+from src.model import RWKV
 
 # helper functions
 
 def exists(val):
     return val is not None
 
+# pairwise ranking loss: the preferred completion should receive the higher reward
+def loss_function(prefer_reward, alter_reward):
+    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
+
 # Reward Model - RWKV with a scalar head
 
 @beartype
-class RewardModel(nn.Module):
+class RewardModel(pl.LightningModule):
     def __init__(
         self,
+        args,
         rwkv: RWKV
     ):
         super().__init__()
+        # the RWKV backbone is passed in by the caller (train_rm.py builds it from args);
+        # its pretrained SFT weights are loaded from args.load_model below
+
+        if len(args.load_model) == 0:
+            rank_zero_info("Reward model training must start from a trained SFT model, please set --load_model")
+            exit(1)
+
+        rank_zero_info(f"########## Loading {args.load_model}... ##########")
+        try:
+            load_dict = torch.load(args.load_model, map_location="cpu")
+        except:
+            rank_zero_info(f"Bad checkpoint {args.load_model}")
+            exit(1)
+
+        if args.load_partial == 1:
+            load_keys = load_dict.keys()
+            for k in rwkv.state_dict():
+                if k not in load_keys:
+                    load_dict[k] = rwkv.state_dict()[k]
+        rwkv.load_state_dict(load_dict)
+
         # 用预训练模型初始化奖励模型
         self.rwkv = rwkv
+        self.args = args
 
         # 输出 token 向量的维度
-        dim = rwkv.args.n_embd
+        dim = self.args.n_embd
 
         # 用于区分输入中的 prompt 和 response,当作模型参数进行训练,初始化为全0
         self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
@@ -56,8 +86,18 @@ class RewardModel(nn.Module):
             *self.to_pred.parameters(),
             *self.rwkv.parameters()
         ]
-
-    def forward(
+
+    def configure_optimizers(self):
+        # hyper-parameters from the InstructGPT paper: lr=1e-5, betas=(0.9, 0.95)
+        # self.parameters() already includes the RWKV backbone, the prompt/response
+        # embeddings and the scalar head, so a single parameter group is enough
+        optimizer = torch.optim.Adam(
+            self.parameters(),
+            lr=self.args.lr_init, betas=self.args.betas)
+
+        return optimizer
+
+    def single_forward(
         self,
         x,
         mask = None,
@@ -89,15 +129,29 @@ class RewardModel(nn.Module):
         last_token_embeds = self.rwkv(
             x,
             state=None,
-            extra_embed=extra_embed
+            extra_embed=extra_embed,
+            rm_train=True
         )
 
         # 所有的 token 向量求平均,并输入到打分模块进行打分
-        try:
-            pooled = masked_mean(last_token_embeds, mask, dim = 1)
-        except:
-            import ipdb
-            ipdb.set_trace()
+        pooled = masked_mean(last_token_embeds, mask, dim = 1)
         reward = self.pred_reward(pooled)
 
         return reward
+
+    def forward(self, prefer_x, alter_x, prefer_mask, alter_mask):
+        prefer_reward = self.single_forward(prefer_x, prefer_mask)
+        alter_reward = self.single_forward(alter_x, alter_mask)
+
+        return prefer_reward, alter_reward
+
+    def training_step(self, batch, batch_idx):
+        # RMDataset yields (preferred ids, alternate ids, preferred mask, alternate mask)
+        prefer_x, alter_x, prefer_mask, alter_mask = batch
+        prefer_reward, alter_reward = self(
+            prefer_x, alter_x, prefer_mask, alter_mask)
+
+        loss = loss_function(prefer_reward, alter_reward)
+
+        return loss
+
diff --git a/src/rlhf/rwkv/model.py b/src/rlhf/rwkv/model.py
index f55e033..751b4c8 100644
--- a/src/rlhf/rwkv/model.py
+++ b/src/rlhf/rwkv/model.py
@@ -365,7 +365,7 @@ class RWKV(MyModule):
 
     ########################################################################################################
 
-    def forward(self, tokens, state, full_output=False, extra_embed=None):
+    def forward(self, tokens, state, full_output=False, extra_embed=None, rm_train=False):
         with torch.no_grad():
             w = self.w
             args = self.args
@@ -384,8 +384,6 @@ class RWKV(MyModule):
                     state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev)
 
             seq_mode = len(tokens) > 1
-            import ipdb
-            ipdb.set_trace()
 
             # 输入:根据 idx 取每个 token 的 embedding
             x = w['emb.weight'][tokens if seq_mode else tokens[0]]
@@ -460,11 +458,12 @@ class RWKV(MyModule):
 
             # 对 token embedding 进行 LayerNorm,维度不变
             x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias'])
-            token_embed = copy.deepcopy(x)
+            if rm_train:
+                return x
 
             if w['head.weight'].dtype != torch.uint8:
                 x = x @ w['head.weight']
             else:
                 x = x @ self.get_w('head.weight', dd.atype)
 
-        return x.float(), state, token_embed.float()
+        return x.float(), state
diff --git a/train_rm.py b/train_rm.py
index ed6ecea..ee87fb9 100644
--- a/train_rm.py
+++ b/train_rm.py
@@ -220,45 +220,44 @@ if __name__ == "__main__":
 
     ########################################################################################################
 
+    # train the reward model
+    def loss_function(prefer_reward, alter_reward):
+        return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
+
+    import torch
+    from tqdm import tqdm
+
+    from src.trainer import train_callback
+    from src.rlhf.reward import RewardModel
+    from src.model import RWKV
+    from src.dataset import RMDataset
+
+    # load the training data
+    train_data = RMDataset(args)
+    args.vocab_size = train_data.vocab_size
+
+    # reward model, built on top of the pretrained RWKV backbone
+    rm_model = RewardModel(args, RWKV(args))
 
-import torch
-
-from src.rlhf.reward import RewardModel
-from src.rlhf.rwkv.model import RWKV
-
-model = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
-strategy = "cpu fp32"
-rwkv_model = RWKV(model, strategy)
-dim = rwkv_model.args.n_embd
-
-reward_model = RewardModel(
-    rwkv_model
-)
-
-# mock data
-prompt = torch.randint(0, dim, (1, 50))
-prefer_response = torch.randint(0, dim, (1, 50))
-alter_response = torch.randint(0, dim, (1, 50))
+    # training
+    trainer = Trainer.from_argparse_args(args, callbacks=[train_callback(args)])
 
-prefer_pair = torch.concat((prompt, prefer_response), dim=1)
-alter_pair = torch.concat((prompt, alter_response), dim=1)
+    if trainer.global_rank == 0:
+        for n in rm_model.state_dict():
+            shape = rm_model.state_dict()[n].shape
+            shape = [i for i in shape if i != 1]
+            if len(shape) > 1:
+                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
+            else:
+                print(f"{str(shape[0]).ljust(5)} {n}")
 
-# which part of the sequence is prompt, which part is response
-prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
-# labels = torch.randint(0, 5, (1,))
+    if "deepspeed" in args.strategy:
+        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
 
-# train
-# loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
-# loss.backward()
+    # shuffle the pairwise comparisons; keep persistent_workers=False (the worker runs in another thread)
+    data_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
 
-# inference
-prefer_reward = reward_model(prefer_pair, prompt_mask = prompt_mask)
-alter_reward = reward_model(alter_pair, prompt_mask = prompt_mask)
+
+    trainer.fit(rm_model, data_loader)
 
-print("Preferred response reward:", prefer_reward)
-print("Alternate response reward:", alter_reward)
\ No newline at end of file
diff --git a/train_rm_demo.py b/train_rm_demo.py
new file mode 100644
index 0000000..3b2a21e
--- /dev/null
+++ b/train_rm_demo.py
@@ -0,0 +1,60 @@
+'''
+@File    :   train_rm_demo.py
+@Time    :   2023/03/10 00:54:57
+@Author  :   Lu Xin
+@Contact :   luxin@csdn.net
+'''
+
+# here put the import lib
+
+import torch
+
+from tqdm import tqdm
+
+from src.rlhf.reward import RewardModel
+from src.rlhf.rwkv.model import RWKV
+
+def loss_function(prefer_reward, alter_reward):
+    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
+
+model = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
+strategy = "cpu fp32"
+rwkv_model = RWKV(model, strategy)
+
+reward_model = RewardModel(
+    rwkv_model
+)
+
+# as used in the InstructGPT paper
+optimizer = torch.optim.Adam(
+    reward_model.parameters(), lr=1e-5, betas=(0.9, 0.95))
+
+# mock data
+vocab_size = 20000
+prompt = torch.randint(0, vocab_size, (1, 50))
+prefer_response = torch.randint(0, vocab_size, (1, 50))
+alter_response = torch.randint(0, vocab_size, (1, 50))
+
+prefer_pair = torch.concat((prompt, prefer_response), dim=1)
+alter_pair = torch.concat((prompt, alter_response), dim=1)
+
+prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
+
+for epoch in range(100):
+    # compute rewards for the preferred and the alternate response
+    prefer_reward = reward_model(prefer_pair, prompt_mask = prompt_mask)
+    alter_reward = reward_model(alter_pair, prompt_mask = prompt_mask)
+    # print(f"prefer_reward: {prefer_reward}")
+    # print(f"alter_reward: {alter_reward}")
+
+    # train
+    loss = loss_function(prefer_reward, alter_reward)
+    print(f"loss: {loss}")
+
+    # Backward pass
+    loss.backward()
+    optimizer.step()
+
+    # Zero the gradients
+    optimizer.zero_grad()
--
GitLab
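The objective this patch trains is the pairwise ranking loss used for InstructGPT-style reward models: L = -log sigmoid(r_preferred - r_alternate), so minimizing it pushes the preferred completion's reward above the alternate's. The snippet below is a minimal, self-contained sketch of that objective with a toy mean-pooled scoring head standing in for the RWKV-backed RewardModel; the ToyRewardModel class, tensor shapes, and vocabulary size are illustrative assumptions, not the repository's API.

```python
# Sketch of the pairwise reward-model objective (toy model, illustrative shapes only).
import torch
import torch.nn as nn
import torch.nn.functional as F

def pairwise_rm_loss(prefer_reward: torch.Tensor, alter_reward: torch.Tensor) -> torch.Tensor:
    # L = -E[log sigmoid(r_preferred - r_alternate)]; logsigmoid is the numerically
    # stable form of log(sigmoid(.)).
    return -F.logsigmoid(prefer_reward - alter_reward).mean()

class ToyRewardModel(nn.Module):
    # Hypothetical stand-in for RewardModel.single_forward: embed tokens, mean-pool over
    # non-padding positions, then score the pooled vector with a linear head.
    def __init__(self, vocab_size=1000, dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, 1)

    def forward(self, ids, pad_mask):
        h = self.embed(ids)                                              # (B, T, D)
        pooled = (h * pad_mask.unsqueeze(-1)).sum(1) / pad_mask.sum(1, keepdim=True)
        return self.head(pooled).squeeze(-1)                             # (B,)

torch.manual_seed(0)
model = ToyRewardModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.95))

# one shared prompt followed by two candidate responses
B, T_prompt, T_resp = 2, 8, 8
prompt = torch.randint(0, 1000, (B, T_prompt))
preferred = torch.cat([prompt, torch.randint(0, 1000, (B, T_resp))], dim=1)
alternate = torch.cat([prompt, torch.randint(0, 1000, (B, T_resp))], dim=1)

# padding mask: nothing is padded here, so all ones; with real data it would be zero
# on the positions padded out to ctx_len
pad_mask = torch.ones(B, T_prompt + T_resp)

for step in range(3):
    loss = pairwise_rm_loss(model(preferred, pad_mask), model(alternate, pad_mask))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"step {step}: loss = {loss.item():.4f}")
```

The sign convention here matches the loss_function used by RewardModel.training_step and by the demo script: the preferred response is the one whose reward the optimizer increases relative to the alternate.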