提交 06865f46 编写于 作者: U u010280923

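opt reward model: RewardModel.__init__ now takes only `args` (the `rwkv: RWKV` parameter and the `from src.model import RWKV` import are removed)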

上级 3c80d013
@@ -31,11 +31,7 @@ def loss_function(prefer_reward, alter_reward):
 @beartype
 class RewardModel(pl.LightningModule):
-    def __init__(
-        self,
-        args,
-        rwkv: RWKV
-    ):
+    def __init__(self, args):
         super().__init__()
         # 加载 RWKV 模型
......
@@ -230,7 +230,6 @@ if __name__ == "__main__":
     from src.trainer import train_callback
     from src.rlhf.reward import RewardModel
-    from src.model import RWKV
     from src.dataset import RMDataset
     # 读入训练数据
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册