def forward(self, x_p, x_a, m_p, m_a):
    """Score both members of a comparison pair with ``single_forward``.

    Both sequences are scored under whatever autograd mode the caller has
    active; this method does not switch gradient tracking on or off itself.

    Args:
        x_p: Input for the preferred sample, passed to ``single_forward``.
        x_a: Input for the alternate sample, passed to ``single_forward``.
        m_p: Prompt mask for ``x_p`` (forwarded as ``prompt_mask``).
        m_a: Prompt mask for ``x_a`` (forwarded as ``prompt_mask``).

    Returns:
        A ``(prefer_reward, alter_reward)`` tuple holding the two values
        returned by ``single_forward``, in that order.
    """
    # Keep the two passes symmetric and leave the grad mode to the caller:
    # - running the alternate pass under torch.no_grad() detaches
    #   alter_reward, so no gradient can reach the model through it and a
    #   loss comparing the two rewards only ever trains on the preferred side;
    # - running the preferred pass under torch.enable_grad() overrides a
    #   caller's torch.no_grad() (e.g. during evaluation) and builds an
    #   autograd graph nobody asked for.
    # NOTE(review): if memory is the concern, presumably gradient
    # checkpointing or a smaller micro-batch is the safer lever - confirm
    # against the training setup.
    prefer_reward = self.single_forward(x_p, prompt_mask=m_p)
    alter_reward = self.single_forward(x_a, prompt_mask=m_a)
    return prefer_reward, alter_reward