From 5bac191b223e2e6c5a8f13a36c4bb181cfa847bf Mon Sep 17 00:00:00 2001
From: u010280923
Date: Fri, 10 Mar 2023 18:29:52 +0800
Subject: [PATCH] bug fixed

---
 src/rlhf/reward.py | 8 ++++++--
 train_rm.py        | 3 ---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index b0a3824..2318dd3 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -18,7 +18,6 @@ from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 from src.rlhf.utils import masked_mean, gumbel_sample
-# from src.model import RWKV
 from src.model import RWKV
 
 # helper functions
@@ -28,7 +27,7 @@ def exists(val):
 
 # loss function
 def loss_function(prefer_reward, alter_reward):
-    return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))
+    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
 
 
 # Reward Model - RWKV with a scalar head
@@ -194,5 +193,10 @@ class RewardModel(pl.LightningModule):
 
         loss = loss_function(prefer_reward, alter_reward)
 
         return loss
+
+    def training_step_end(self, batch_parts):
+        all = self.all_gather(batch_parts)
+        if self.trainer.is_global_zero:
+            self.trainer.my_loss_all = all
 
diff --git a/train_rm.py b/train_rm.py
index b0ef71a..84f6143 100644
--- a/train_rm.py
+++ b/train_rm.py
@@ -221,9 +221,6 @@ if __name__ == "__main__":
 
    ########################################################################################################
 
    # Train the RM model
-    def loss_function(prefer_reward, alter_reward):
-        return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))
-
    import torch
    from tqdm import tqdm
-- 
GitLab
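
Note on the fix: the substantive change is the sign inside the pairwise preference loss. A reward model trained on human comparisons should assign the preferred completion a higher score, so the loss is -log(sigmoid(prefer_reward - alter_reward)); the reversed difference would instead push the model toward the rejected answer. The following sketch is illustrative only (it is not part of the patch, and the example reward values are made up); it shows that the corrected loss is small when the preferred completion already scores higher and large when the ranking is inverted.

import torch

def loss_function(prefer_reward, alter_reward):
    # Pairwise (Bradley-Terry style) preference loss: maximize the
    # probability that the preferred completion outscores the alternative.
    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))

# Hypothetical scalar rewards for two comparison pairs.
prefer = torch.tensor([2.0, 1.5])  # rewards of the human-preferred completions
alter = torch.tensor([0.5, 0.3])   # rewards of the rejected completions

print(loss_function(prefer, alter))  # small loss (~0.23): ranking already correct
print(loss_function(alter, prefer))  # large loss (~1.58): ranking inverted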