'''
@File    :   train_rm_demo.py
@Time    :   2023/03/10 00:54:57
@Author  :   Lu Xin
@Contact :   luxin@csdn.net
'''

# here put the import lib
import torch
from tqdm import tqdm

from src.rlhf.reward import RewardModel
from src.rlhf.rwkv.model import RWKV


def loss_function(prefer_reward, alter_reward):
    # Pairwise ranking loss from the InstructGPT paper:
    # maximize log(sigmoid(r_prefer - r_alter)), i.e. push the preferred
    # response's reward above the alternative response's reward.
    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))


model = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
strategy = "cpu fp32"
rwkv_model = RWKV(model, strategy)

reward_model = RewardModel(rwkv_model)

# optimizer hyperparameters as used in the InstructGPT paper
optimizer = torch.optim.Adam(
    reward_model.parameters(), lr=1e-5, betas=(0.9, 0.95))

# dummy data: one prompt plus a preferred and an alternative response
dim = 20000
prompt = torch.randint(0, dim, (1, 50))
prefer_response = torch.randint(0, dim, (1, 50))
alter_response = torch.randint(0, dim, (1, 50))

prefer_pair = torch.cat((prompt, prefer_response), dim=1)
alter_pair = torch.cat((prompt, alter_response), dim=1)

# mask marking prompt tokens (True) vs. response tokens (False)
prompt_mask = torch.cat(
    (torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)

for epoch in tqdm(range(100)):
    # compute a scalar reward for each (prompt, response) pair
    prefer_reward = reward_model(prefer_pair, prompt_mask=prompt_mask)
    alter_reward = reward_model(alter_pair, prompt_mask=prompt_mask)

    # train
    loss = loss_function(prefer_reward, alter_reward)
    print(f"loss: {loss}")

    # Backward pass
    loss.backward()
    optimizer.step()

    # Zero the gradients
    optimizer.zero_grad()
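
# --- Illustrative sketch, not part of the original demo ---
# After training, the reward model should score the preferred pair higher than
# the alternative pair. This reuses the same reward_model call signature as the
# training loop above; whether any extra arguments are required depends on the
# RewardModel implementation.
with torch.no_grad():
    final_prefer_reward = reward_model(prefer_pair, prompt_mask=prompt_mask)
    final_alter_reward = reward_model(alter_pair, prompt_mask=prompt_mask)
    print(f"final prefer_reward: {final_prefer_reward}")
    print(f"final alter_reward:  {final_alter_reward}")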