提交 dfeee746 编写于 作者: U u010280923

opt reward model

上级 06865f46
......@@ -68,7 +68,7 @@ class RewardModel(pl.LightningModule):
# reward 得分计算
self.pred_reward = nn.Sequential(
-nn.Linear(dim, 1),
+nn.Linear(dim, 1, bias=False),
Rearrange('... 1 -> ...') # 降维
)
......
......@@ -240,7 +240,10 @@ if __name__ == "__main__":
rm_model = RewardModel(args)
# 训练
-trainer = Trainer.from_argparse_args()
+trainer = Trainer.from_argparse_args(
+args,
+callbacks=[train_callback(args)],
+)
if trainer.global_rank == 0:
for n in rm_model.state_dict():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册