From 68662e236d64730b28a1b909c267f4da092a3f04 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Mon, 13 Mar 2023 10:32:51 +0800
Subject: [PATCH] opt reward model

---
 src/rlhf/reward.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index 0b8f652..cd03bef 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -181,8 +181,10 @@ class RewardModel(pl.LightningModule):
         return reward
 
     def forward(self, x_p, x_a, m_p, m_a):
-        prefer_reward = self.single_forward(x_p, prompt_mask=m_p)
-        alter_reward = self.single_forward(x_a, prompt_mask=m_a)
+        with torch.enable_grad():
+            prefer_reward = self.single_forward(x_p, prompt_mask=m_p)
+        with torch.no_grad():
+            alter_reward = self.single_forward(x_a, prompt_mask=m_a)
 
         return prefer_reward, alter_reward
 
-- 
GitLab