Commit 26b8d465 authored by u010280923

opt reward model

Parent 82d6d979
@@ -250,7 +250,7 @@ def clipped_value_loss(values, rewards, old_values, clip):
 # rlhf

 @beartype
-class RLHF(nn.Module):
+class RLHF(pl.LightningModule):
     def __init__(
         self,
         args
...
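Rebasing RLHF from nn.Module onto pl.LightningModule hands its training loop to PyTorch Lightning, the same way RewardModel is already trained. A minimal sketch of the contract that change implies, assuming nothing about the repo's actual PPO logic (the module below is purely illustrative):

import torch
import pytorch_lightning as pl

class ToyLightningModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)   # stand-in for the real network

    def training_step(self, batch, batch_idx):
        # called by pl.Trainer once per batch; must return the loss
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        # pl.Trainer asks the module to build its own optimizer
        return torch.optim.Adam(self.parameters(), lr=1e-3)

With those two methods defined, pl.Trainer (e.g. Trainer.from_argparse_args(...), as used at the bottom of this diff) can drive the optimization without a hand-written loop.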
@@ -21,6 +21,7 @@ from einops.layers.torch import Rearrange, Reduce
 from src.rlhf.utils import masked_mean, gumbel_sample
+from src.model import RWKV

 # helper functions
 def exists(val):
...
@@ -34,30 +35,9 @@ def loss_function(prefer_reward, alter_reward):
 @beartype
 class RewardModel(pl.LightningModule):
-    def __init__(self, args):
+    def __init__(self, args, rwkv: RWKV):
         super().__init__()

-        # load the RWKV model
-        rwkv = RWKV(args)
-        if len(args.load_model) == 0:
-            rank_zero_info(f"SFT must load model, please input ")
-            exit(1)
-        rank_zero_info(f"########## Loading {args.load_model}... ##########")
-        try:
-            load_dict = torch.load(args.load_model, map_location="cpu")
-        except:
-            rank_zero_info(f"Bad checkpoint {args.load_model}")
-            exit(1)
-        if args.load_partial == 1:
-            load_keys = load_dict.keys()
-            for k in rwkv.state_dict():
-                if k not in load_keys:
-                    load_dict[k] = rwkv.state_dict()[k]
-        rwkv.load_state_dict(load_dict)
-
         # initialize the reward model from the pretrained (SFT) model
         self.rwkv = rwkv
         self.args = args
...
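The second hunk header references loss_function(prefer_reward, alter_reward), whose body lies outside this diff. Reward models trained on preference pairs conventionally use a pairwise (Bradley–Terry) ranking loss, so a plausible reading is the following; the signature comes from the diff, but the body here is an assumption, not the repo's confirmed code:

import torch.nn.functional as F

def loss_function(prefer_reward, alter_reward):
    # Pairwise ranking loss: push the preferred completion's scalar
    # reward above the alternative's.
    # NOTE: assumed implementation; the diff shows only the signature.
    return -F.logsigmoid(prefer_reward - alter_reward).mean()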
@@ -57,7 +57,7 @@ if __name__ == "__main__":
 parser = ArgumentParser()

-parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
+parser.add_argument("--load_sft_model", default="", type=str)  # full path, with .pth
 parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
 parser.add_argument("--proj_dir", default="out", type=str)
 parser.add_argument("--random_seed", default="-1", type=int)
@@ -228,13 +228,35 @@ if __name__ == "__main__":
 from src.trainer import rm_train_callback
 from src.rlhf.reward import RewardModel
 from src.dataset import RMDataset
+from src.model import RWKV

 # read in the training data
 train_data = RMDataset(args)
 args.vocab_size = train_data.vocab_size

-# the RM model
-rm_model = RewardModel(args)
+# load the RWKV model
+rwkv = RWKV(args)
+if len(args.load_sft_model) == 0:
+    rank_zero_info(f"SFT must load model, please input ")
+    exit(1)
+rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
+try:
+    load_dict = torch.load(args.load_sft_model, map_location="cpu")
+except:
+    rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
+    exit(1)
+if args.load_partial == 1:
+    load_keys = load_dict.keys()
+    for k in rwkv.state_dict():
+        if k not in load_keys:
+            load_dict[k] = rwkv.state_dict()[k]
+rwkv.load_state_dict(load_dict)
+
+# initialize the RM model
+rm_model = RewardModel(args, rwkv)

 # training
 trainer = Trainer.from_argparse_args(
...
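The relocated block is inlined straight into the training script. The same logic reads a little cleaner factored into a helper; this is a hedged sketch, not code from the repo — load_sft_checkpoint is a hypothetical name, and it swaps the bare except: for a narrower Exception handler:

import torch

def load_sft_checkpoint(rwkv, path, load_partial=False):
    # hypothetical helper mirroring the inline logic above
    if not path:
        raise ValueError("an SFT checkpoint is required (--load_sft_model)")
    try:
        load_dict = torch.load(path, map_location="cpu")
    except Exception as e:  # narrower than the inline bare `except:`
        raise RuntimeError(f"Bad checkpoint {path}") from e
    if load_partial:
        # back-fill keys absent from the checkpoint with the model's
        # freshly initialized weights, as the `args.load_partial == 1` branch does
        for k, v in rwkv.state_dict().items():
            load_dict.setdefault(k, v)
    rwkv.load_state_dict(load_dict)
    return rwkv

Constructing RWKV outside RewardModel and injecting it through the constructor also means a later stage (e.g. PPO) could reuse the already-loaded backbone instead of reloading it from disk.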