提交 06865f46 编写于 作者: U u010280923

opt reward model

上级 3c80d013
......@@ -31,11 +31,7 @@ def loss_function(prefer_reward, alter_reward):
 @beartype
 class RewardModel(pl.LightningModule):
-    def __init__(
-        self,
-        args,
-        rwkv: RWKV
-    ):
+    def __init__(self, args):
         super().__init__()
         # 加载 RWKV 模型
......
......@@ -230,7 +230,6 @@ if __name__ == "__main__":
from src.trainer import train_callback
from src.rlhf.reward import RewardModel
from src.model import RWKV
from src.dataset import RMDataset
# 读入训练数据
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册