From 9463b004be49ceb66c23f504872c1451c15f170d Mon Sep 17 00:00:00 2001 From: u010280923 Date: Mon, 20 Mar 2023 17:49:19 +0800 Subject: [PATCH] opt ppo model --- src/rlhf/ppo.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py index 0e1dc65..f815eae 100644 --- a/src/rlhf/ppo.py +++ b/src/rlhf/ppo.py @@ -323,8 +323,10 @@ class RLHF(pl.LightningModule): def generate( self, max_seq_len, + *args, prompt, - num_samples = 4 # sample 4 per prompt and select the one with highest reward + num_samples = 4, # sample 4 per prompt and select the one with highest reward + **kwargs ): ''' 未参与训练,仅推理时使用 ''' @@ -342,8 +344,10 @@ class RLHF(pl.LightningModule): _ ) = self.actor_critic.generate( prompt, + *args, max_seq_len = max_seq_len, - return_values = False + return_values = False, + **kwargs ) rewards = self.reward_model( -- GitLab