提交 26b8d465 编写于 作者: U u010280923

Optimize the reward model: inject the pretrained RWKV backbone into RewardModel via its constructor instead of loading the checkpoint inside the class; move SFT-checkpoint loading (--load_sft_model) into the training entry script.

上级 82d6d979
......@@ -250,7 +250,7 @@ def clipped_value_loss(values, rewards, old_values, clip):
# rlhf
@beartype
class RLHF(nn.Module):
class RLHF(pl.LightningModule):
def __init__(
self,
args
......
......@@ -21,6 +21,7 @@ from einops.layers.torch import Rearrange, Reduce
from src.rlhf.utils import masked_mean, gumbel_sample
from src.model import RWKV
# helper functions
def exists(val):
......@@ -34,30 +35,9 @@ def loss_function(prefer_reward, alter_reward):
@beartype
class RewardModel(pl.LightningModule):
def __init__(self, args):
def __init__(self, args, rwkv: RWKV):
super().__init__()
# 加载 RWKV 模型
rwkv = RWKV(args)
if len(args.load_model) == 0:
rank_zero_info(f"SFT must load model, please input ")
exit(1)
rank_zero_info(f"########## Loading {args.load_model}... ##########")
try:
load_dict = torch.load(args.load_model, map_location="cpu")
except:
rank_zero_info(f"Bad checkpoint {args.load_model}")
exit(1)
if args.load_partial == 1:
load_keys = load_dict.keys()
for k in rwkv.state_dict():
if k not in load_keys:
load_dict[k] = rwkv.state_dict()[k]
rwkv.load_state_dict(load_dict)
# 用预训练模型初始化奖励模型
self.rwkv = rwkv
self.args = args
......
......@@ -57,7 +57,7 @@ if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
parser.add_argument("--load_sft_model", default="", type=str) # full path, with .pth
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--random_seed", default="-1", type=int)
......@@ -228,13 +228,35 @@ if __name__ == "__main__":
from src.trainer import rm_train_callback
from src.rlhf.reward import RewardModel
from src.dataset import RMDataset
from src.model import RWKV
# 读入训练数据
train_data = RMDataset(args)
args.vocab_size = train_data.vocab_size
# RM 模型
rm_model = RewardModel(args)
# 加载 RWKV 模型
rwkv = RWKV(args)
if len(args.load_sft_model) == 0:
rank_zero_info(f"SFT must load model, please input ")
exit(1)
rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
try:
load_dict = torch.load(args.load_sft_model, map_location="cpu")
except:
rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
exit(1)
if args.load_partial == 1:
load_keys = load_dict.keys()
for k in rwkv.state_dict():
if k not in load_keys:
load_dict[k] = rwkv.state_dict()[k]
rwkv.load_state_dict(load_dict)
# 初始化 RM 模型
rm_model = RewardModel(args, rwkv)
# 训练
trainer = Trainer.from_argparse_args(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册