Commit 5bac191b authored by u010280923

bug fixed

Parent 142305ef
@@ -18,7 +18,6 @@ from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 from src.rlhf.utils import masked_mean, gumbel_sample
-# from src.model import RWKV
 from src.model import RWKV
 # helper functions
@@ -28,7 +27,7 @@ def exists(val):
 # loss function
 def loss_function(prefer_reward, alter_reward):
-    return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))
+    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
 # Reward Model - RWKV with a scalar head
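The corrected loss is the standard pairwise ranking objective for reward models: minimize -log σ(r_preferred - r_rejected), so the loss falls as the preferred completion is scored above the alternative. Below is a minimal numerical sketch (the tensors `prefer` and `alter` are made up for illustration) showing why the previous argument order penalized the correct ranking:

```python
import torch

# Corrected pairwise ranking loss: small when prefer_reward > alter_reward.
def loss_function(prefer_reward, alter_reward):
    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))

prefer = torch.tensor([2.0, 1.5])  # rewards for preferred completions (illustrative)
alter = torch.tensor([0.5, 0.0])   # rewards for rejected completions (illustrative)

print(loss_function(prefer, alter))  # ~0.20: correct ranking, low loss
print(loss_function(alter, prefer))  # ~1.70: inverted arguments, high loss (the old bug)
```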
@@ -195,4 +194,9 @@ class RewardModel(pl.LightningModule):
         return loss
+
+    def training_step_end(self, batch_parts):
+        all = self.all_gather(batch_parts)
+        if self.trainer.is_global_zero:
+            self.trainer.my_loss_all = all
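The new `training_step_end` hook relies on two PyTorch Lightning facilities: `LightningModule.all_gather`, which collects the per-process loss across devices (a no-op on a single device), and `trainer.is_global_zero`, which is true only on rank 0, so the gathered tensor is stashed on the trainer once. A self-contained sketch of the same pattern follows; it assumes pytorch_lightning 1.x (where `training_step_end` still exists), and `ToyModule`, the random dataset, and the `my_loss_all` attribute name are illustrative only:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl  # assumes a 1.x release that supports training_step_end


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def training_step_end(self, batch_parts):
        # Gather the loss from every process; on a single device this just returns it.
        gathered = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            # Keep the gathered losses on the trainer (e.g. for custom logging).
            self.trainer.my_loss_all = gathered
        return batch_parts  # hand the loss back so backprop proceeds as usual

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    data = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
    trainer = pl.Trainer(fast_dev_run=True, accelerator="cpu")
    trainer.fit(ToyModule(), DataLoader(data, batch_size=8))
    print(trainer.my_loss_all)  # the loss gathered in training_step_end
```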
@@ -221,9 +221,6 @@ if __name__ == "__main__":
 ########################################################################################################
 # 训练 RM 模型
-
-def loss_function(prefer_reward, alter_reward):
-    return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))
 import torch
 from tqdm import tqdm
...