From 06865f467e837720258147b7b8fbfea1fc537c75 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 10 Mar 2023 13:57:34 +0800 Subject: [PATCH] opt reward model --- src/rlhf/reward.py | 6 +----- train_rm.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 1b60359..1b1ba9d 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -31,11 +31,7 @@ def loss_function(prefer_reward, alter_reward): @beartype class RewardModel(pl.LightningModule): - def __init__( - self, - args, - rwkv: RWKV - ): + def __init__(self, args): super().__init__() # 加载 RWKV 模型 diff --git a/train_rm.py b/train_rm.py index ee87fb9..94d2065 100644 --- a/train_rm.py +++ b/train_rm.py @@ -230,7 +230,6 @@ if __name__ == "__main__": from src.trainer import train_callback from src.rlhf.reward import RewardModel - from src.model import RWKV from src.dataset import RMDataset # 读入训练数据 -- GitLab