"""
@File    : train_rlhf.py
@Time    : 2023/03/08 15:23:19
@Author  : Lu Xin
@Contact : luxin@csdn.net

Entry-point script for the RLHF (PPO) stage: wires a pretrained RWKV
policy model and a trained reward model into the PPO trainer, trains on
a set of prompts, then generates samples and lets the reward model pick
the best one.
"""

# here put the import lib
import torch

from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.ppo import RLHFTrainer

# Load your pretrained RWKV policy model.
# TODO(luxin): load the checkpoint produced by the SFT stage, e.g.
#   rwkv_model.load('./path/to/pretrained/rwkv.pt')
rwkv_model = RWKV()

# Load your pretrained reward model (built on top of the same RWKV trunk).
# TODO(luxin): load the trained reward-model weights, e.g.
#   reward_model.load('./path/to/pretrained/reward_model.pt')
reward_model = RewardModel(
    rwkv_model,
    num_binned_output=5,
)

# Prompts for reinforcement learning.
# TODO(luxin): read the real prompt dataset here (these prompts should be
# different from the ones used in the SFT and reward-model stages). The
# random tensor below is a placeholder: 50k prompts of 512 token ids each.
prompts = torch.randint(0, 256, (50000, 512))  # 50k prompts

# Hand everything to the PPO trainer and train.
# FIX: the original code passed an undefined name `palm`; the policy model
# constructed above is `rwkv_model` (the `palm=` keyword is kept because it
# is the trainer's parameter name).
trainer = RLHFTrainer(
    palm=rwkv_model,
    reward_model=reward_model,
    prompt_token_ids=prompts,
)
trainer.train(num_episodes=100)

# Then, if training succeeded...
# Generate 10 samples for the first prompt and use the reward model to
# return the best one.
answer = trainer.generate(2048, prompt=prompts[0], num_samples=10)  # (<= 2048,)
print(answer)