Commit 5bac191b authored by u010280923

bug fixed

Parent 142305ef
@@ -18,7 +18,6 @@ from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 from src.rlhf.utils import masked_mean, gumbel_sample
-# from src.model import RWKV
 from src.model import RWKV

 # helper functions
@@ -28,7 +27,7 @@ def exists(val):
 # loss function
 def loss_function(prefer_reward, alter_reward):
-    return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))
+    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))

 # Reward Model - RWKV with a scalar head
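For context on the fix: the corrected line is the standard pairwise preference loss for reward models, -log σ(r_prefer - r_alter), which shrinks as the preferred completion's reward rises above the alternative's; the old sign optimized the opposite ordering. A minimal sketch of the corrected behavior (the tensor values below are illustrative, not from the repository):

```python
import torch

def loss_function(prefer_reward, alter_reward):
    # Pairwise preference loss: -E[log sigmoid(r_prefer - r_alter)].
    # Minimized when the preferred sample consistently scores higher.
    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))

# Hypothetical scalar rewards for a batch of three preference pairs
prefer_reward = torch.tensor([1.2, 0.8, 2.0])
alter_reward = torch.tensor([0.3, 1.1, -0.5])

fixed = loss_function(prefer_reward, alter_reward)
buggy = -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))
print(fixed, buggy)  # the buggy sign yields the larger loss on these mostly-correct margins
```

In practice `torch.nn.functional.logsigmoid(x)` is a numerically safer equivalent of `torch.log(torch.sigmoid(x))` when margins are large and negative.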
@@ -195,4 +194,9 @@ class RewardModel(pl.LightningModule):
         return loss
+
+    def training_step_end(self, batch_parts):
+        all = self.all_gather(batch_parts)
+        if self.trainer.is_global_zero:
+            self.trainer.my_loss_all = all
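The new `training_step_end` hook collects the per-process loss under multi-GPU training: LightningModule's `self.all_gather` stacks `batch_parts` across ranks, and only the global-zero process stores the result on the trainer, presumably for custom logging. A minimal sketch of the same pattern (this hook exists in pre-2.0 PyTorch Lightning, which this code appears to target; the module and attribute names other than the hooks are illustrative):

```python
import torch
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):  # illustrative stand-in for RewardModel
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Each DDP process computes and returns its own scalar loss.
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def training_step_end(self, batch_parts):
        # all_gather returns the losses from every process, stacked along a
        # new leading dimension of size world_size.
        gathered = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            # Stash the gathered losses on the trainer, mirroring the commit's
            # `self.trainer.my_loss_all = all`.
            self.trainer.my_loss_all = gathered

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

One small style point: the commit's local variable `all` shadows Python's builtin `all`; a name like `gathered` avoids that confusion.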
@@ -221,9 +221,6 @@ if __name__ == "__main__":
 ########################################################################################################
 # Train the RM model
-def loss_function(prefer_reward, alter_reward):
-    return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))
-import torch
 from tqdm import tqdm
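Removing this duplicate definition matters beyond tidiness: it still carried the old, inverted sign, and because it appeared later in the file it would have rebound `loss_function` and silently overridden the corrected version whenever the file was run as a script.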