'''
@File    :   train_rm_demo.py
@Time    :   2023/03/10 00:54:57
@Author  :   Lu Xin 
@Contact :   luxin@csdn.net
'''

# imports

import torch

from tqdm import tqdm

from src.rlhf.reward import RewardModel
from src.rlhf.rwkv.model import RWKV

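# Pairwise ranking loss for reward-model training: under the Bradley-Terry model,
# P(prefer beats alter) = sigmoid(r_prefer - r_alter), so minimizing
# -log(sigmoid(r_prefer - r_alter)) pushes the preferred response's reward above
# the alternative's (the InstructGPT reward-model objective for one comparison).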
def loss_function(prefer_reward, alter_reward):
    # the loss falls as the preferred response out-scores the alternative
    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))

# 169M RWKV-4 checkpoint, loaded on CPU in full precision
model_path = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
strategy = "cpu fp32"
rwkv_model = RWKV(model_path, strategy)

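# RewardModel (src/rlhf/reward.py) presumably wraps the RWKV trunk and adds a
# scalar reward head; the exact constructor options live in that module.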
reward_model = RewardModel(
    rwkv_model
)

# import ipdb; ipdb.set_trace()  # optional: uncomment to drop into a debugger here

# as used in the InstructGPT paper
optimizer = torch.optim.Adam(
    reward_model.parameters(), lr=1e-5, betas=(0.9, 0.95)) 
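# Note: reward_model.parameters() presumably includes the RWKV trunk's weights,
# so this demo fine-tunes the whole model rather than only a reward head.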

# Dummy data: random token ids standing in for a tokenized prompt and two candidate responses
vocab_size = 20000
prompt = torch.randint(0, vocab_size, (1, 50))
prefer_response = torch.randint(0, vocab_size, (1, 50))
alter_response = torch.randint(0, vocab_size, (1, 50))

# each training example is the prompt followed by one candidate response
prefer_pair = torch.cat((prompt, prefer_response), dim=1)
alter_pair = torch.cat((prompt, alter_response), dim=1)

# True for the 50 prompt positions, False for the 50 response positions
prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
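# In a real run these tensors would come from a human-preference dataset:
# each comparison tokenized into (prompt, preferred response, rejected response).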

for epoch in tqdm(range(100)):
    # score both (prompt, response) pairs with the current reward model
    prefer_reward = reward_model(prefer_pair, prompt_mask=prompt_mask)
    alter_reward = reward_model(alter_pair, prompt_mask=prompt_mask)
    # print(f"prefer_reward: {prefer_reward}")
    # print(f"alter_reward: {alter_reward}")

    # pairwise ranking loss and parameter update
    loss = loss_function(prefer_reward, alter_reward)
    tqdm.write(f"epoch {epoch}: loss = {loss.item():.4f}")

    # Backward pass
    loss.backward()
    optimizer.step()

    # Zero the gradients
    optimizer.zero_grad()
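
# Sanity check (a rough sketch): after training on this single synthetic comparison,
# the preferred pair should out-score the alternative. Assumes RewardModel is an
# nn.Module returning one scalar reward per input sequence.
reward_model.eval()
with torch.no_grad():
    final_prefer = reward_model(prefer_pair, prompt_mask=prompt_mask)
    final_alter = reward_model(alter_pair, prompt_mask=prompt_mask)
print(f"prefer reward: {float(final_prefer):.4f}, alter reward: {float(final_alter):.4f}")
assert (final_prefer > final_alter).all(), "the preference on the toy pair was not learned"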