diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index b0a3824dbefa3140deb3aa4ffe43fa0978b657fa..2318dd3c4760c960e4923ee1a85adaa880303ec0 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -18,7 +18,6 @@
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 from src.rlhf.utils import masked_mean, gumbel_sample
-# from src.model import RWKV
 from src.model import RWKV
 
 # helper functions
@@ -28,7 +27,7 @@ def exists(val):
 
 # loss function
 def loss_function(prefer_reward, alter_reward):
-    return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))
+    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
 
 
 # Reward Model - RWKV with a scalar head
@@ -194,5 +193,10 @@ class RewardModel(pl.LightningModule):
 
         loss = loss_function(prefer_reward, alter_reward)
 
         return loss
 
+
+    def training_step_end(self, batch_parts):
+        all = self.all_gather(batch_parts)
+        if self.trainer.is_global_zero:
+            self.trainer.my_loss_all = all
diff --git a/train_rm.py b/train_rm.py
index b0ef71aa64c5a5baac30265cf5b09bb1d9819e0a..84f61439bbbb02183f4e55ccefb226183a1b5de4 100644
--- a/train_rm.py
+++ b/train_rm.py
@@ -221,9 +221,6 @@ if __name__ == "__main__":
 
     ########################################################################################################
     # Train the RM model
 
-    def loss_function(prefer_reward, alter_reward):
-        return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))
-
     import torch
     from tqdm import tqdm
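
Note on the sign flip in loss_function: the pairwise reward-model objective is -log sigmoid(r_preferred - r_rejected), which only shrinks toward zero when the preferred completion is scored higher than the rejected one; the old ordering rewarded the opposite ranking. A minimal standalone sketch of the corrected loss follows (the function and variable names here are illustrative, not taken from the repo):

import torch
import torch.nn.functional as F

def pairwise_preference_loss(prefer_reward: torch.Tensor, alter_reward: torch.Tensor) -> torch.Tensor:
    # -log(sigmoid(r_prefer - r_alter)): near zero when the preferred sample
    # out-scores the rejected one, large when the ordering is inverted.
    return -F.logsigmoid(prefer_reward - alter_reward).mean()

# toy sanity check: scoring the preferred samples higher gives the smaller loss
prefer = torch.tensor([2.0, 1.5])
alter = torch.tensor([0.5, 1.0])
assert pairwise_preference_loss(prefer, alter) < pairwise_preference_loss(alter, prefer)

F.logsigmoid is used in this sketch only because it is numerically safer than composing torch.log with torch.sigmoid; the patched loss_function keeps the original composition and computes the same quantity.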