提交 68662e23 编写于 作者: U u010280923

opt reward model

上级 f7516c41
@@ -181,8 +181,10 @@ class RewardModel(pl.LightningModule):
         return reward
     def forward(self, x_p, x_a, m_p, m_a):
-        prefer_reward = self.single_forward(x_p, prompt_mask=m_p)
-        alter_reward = self.single_forward(x_a, prompt_mask=m_a)
+        with torch.enable_grad():
+            prefer_reward = self.single_forward(x_p, prompt_mask=m_p)
+        with torch.no_grad():
+            alter_reward = self.single_forward(x_a, prompt_mask=m_a)
         return prefer_reward, alter_reward
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册