diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py index 1b603597a0bc8d5f1cea1efd9bf4c6e4028429a4..1b1ba9dc3125c180a8cab00c13ad606f96e91011 100644 --- a/src/rlhf/reward.py +++ b/src/rlhf/reward.py @@ -31,11 +31,7 @@ def loss_function(prefer_reward, alter_reward): @beartype class RewardModel(pl.LightningModule): - def __init__( - self, - args, - rwkv: RWKV - ): + def __init__(self, args): super().__init__() # 加载 RWKV 模型 diff --git a/train_rm.py b/train_rm.py index ee87fb95dd28f457bc1399247ecf7c13f6335329..94d206540612c3d1e5cebc44c095fc095eaa680e 100644 --- a/train_rm.py +++ b/train_rm.py @@ -230,7 +230,6 @@ if __name__ == "__main__": from src.trainer import train_callback from src.rlhf.reward import RewardModel - from src.model import RWKV from src.dataset import RMDataset # 读入训练数据