diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index 8652895efbbfa0c0fe13cabc31f85dc7fc7c7471..41137e02cb924f011be9ea160762686770943307 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -250,7 +250,7 @@ def clipped_value_loss(values, rewards, old_values, clip):
 # rlhf
 
 @beartype
-class RLHF(nn.Module):
+class RLHF(pl.LightningModule):
     def __init__(
         self,
         args
diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index 580ac19ff24b423e2f50a64a7fa940008adbf25c..639b2355f3db5cf2d6f4b68d9543143725f92324 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -21,6 +21,7 @@ from einops.layers.torch import Rearrange, Reduce
 from src.rlhf.utils import masked_mean, gumbel_sample
 from src.model import RWKV
+
 
 # helper functions
 
 def exists(val):
@@ -34,30 +35,9 @@ def loss_function(prefer_reward, alter_reward):
 
 @beartype
 class RewardModel(pl.LightningModule):
-    def __init__(self, args):
+    def __init__(self, args, rwkv: RWKV):
         super().__init__()
 
-        # load the RWKV model
-        rwkv = RWKV(args)
-
-        if len(args.load_model) == 0:
-            rank_zero_info(f"SFT must load model, please input ")
-            exit(1)
-
-        rank_zero_info(f"########## Loading {args.load_model}... ##########")
-        try:
-            load_dict = torch.load(args.load_model, map_location="cpu")
-        except:
-            rank_zero_info(f"Bad checkpoint {args.load_model}")
-            exit(1)
-
-        if args.load_partial == 1:
-            load_keys = load_dict.keys()
-            for k in rwkv.state_dict():
-                if k not in load_keys:
-                    load_dict[k] = rwkv.state_dict()[k]
-            rwkv.load_state_dict(load_dict)
-
         # initialize the reward model from the pretrained model
         self.rwkv = rwkv
         self.args = args
diff --git a/train_rm.py b/train_rm.py
index 9e980be703ae78763aca33e1bcb69a5ca3fd5bd3..59e3cf220cd4630aa9049e6c52e28f28cb7c9d45 100644
--- a/train_rm.py
+++ b/train_rm.py
@@ -57,7 +57,7 @@ if __name__ == "__main__":
 
     parser = ArgumentParser()
 
-    parser.add_argument("--load_model", default="", type=str) # full path, with .pth
+    parser.add_argument("--load_sft_model", default="", type=str) # full path, with .pth
     parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
     parser.add_argument("--proj_dir", default="out", type=str)
     parser.add_argument("--random_seed", default="-1", type=int)
@@ -228,13 +228,35 @@ if __name__ == "__main__":
     from src.trainer import rm_train_callback
     from src.rlhf.reward import RewardModel
     from src.dataset import RMDataset
+    from src.model import RWKV
 
     # load the training data
     train_data = RMDataset(args)
     args.vocab_size = train_data.vocab_size
 
-    # RM model
-    rm_model = RewardModel(args)
+    # load the RWKV model
+    rwkv = RWKV(args)
+
+    if len(args.load_sft_model) == 0:
+        rank_zero_info("RM training must load an SFT model, please set --load_sft_model")
+        exit(1)
+
+    rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
+    try:
+        load_dict = torch.load(args.load_sft_model, map_location="cpu")
+    except:
+        rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
+        exit(1)
+
+    if args.load_partial == 1:
+        load_keys = load_dict.keys()
+        for k in rwkv.state_dict():
+            if k not in load_keys:
+                load_dict[k] = rwkv.state_dict()[k]
+        rwkv.load_state_dict(load_dict)
+
+    # initialize the RM model
+    rm_model = RewardModel(args, rwkv)
 
     # training
     trainer = Trainer.from_argparse_args(
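
For reference, a minimal usage sketch of the refactored construction flow introduced by this patch (not part of the diff itself; it assumes an args namespace prepared the same way as in train_rm.py): the SFT checkpoint is loaded into an RWKV instance outside the reward model, and the pretrained module is then passed into RewardModel through its new constructor signature.

    # Minimal sketch, assuming `args` is built as in train_rm.py.
    import torch

    from src.model import RWKV
    from src.rlhf.reward import RewardModel

    rwkv = RWKV(args)                                               # base model built from args
    load_dict = torch.load(args.load_sft_model, map_location="cpu")
    rwkv.load_state_dict(load_dict)                                 # start from the SFT weights

    rm_model = RewardModel(args, rwkv)                              # reward model now receives the pretrained RWKV module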