"""
@File    :   train_rm.py
@Time    :   2023/03/08 15:23:29
@Author  :   Lu Xin
@Contact :   luxin@csdn.net

Minimal reward-model example: wrap an RWKV model in a RewardModel, run one
forward/backward pass on random mock data, then call the model again without
labels to obtain a reward.
"""

import torch

from src.rlhf.reward import RewardModel
from src.model import RWKV

# Number of score levels; with 5 the possible scores are [0, 1, 2, 3, 4].
num_score_levels = 5

# Shape of the mock batch: one sequence of `seq_len` tokens, of which the
# first `prompt_len` positions are flagged in the prompt mask.
seq_len = 100
prompt_len = 50

rwkv_model = RWKV()
reward_model = RewardModel(rwkv_model, num_binned_output=num_score_levels)

# Mock data: random token ids in [0, 20000).
seq = torch.randint(0, 20000, (1, seq_len))

# Which part of the sequence is prompt and which part is response.
# The first half is True and the second half is False.
# NOTE(review): presumably True marks prompt tokens -- confirm against
# RewardModel.
prompt_mask = torch.zeros(1, seq_len, dtype=torch.bool)
prompt_mask[:, :prompt_len] = True

# One random target score level per sequence.
labels = torch.randint(0, num_score_levels, (1,))

# Train: passing labels makes the call return a loss to backpropagate.
loss = reward_model(seq, prompt_mask=prompt_mask, labels=labels)
loss.backward()

# After much training: the same call without labels yields the reward.
reward = reward_model(seq, prompt_mask=prompt_mask)