From 8aabd2fc278023038e41595f46c424fc20b2ff6e Mon Sep 17 00:00:00 2001
From: u010280923
Date: Mon, 20 Mar 2023 18:17:00 +0800
Subject: [PATCH] opt ppo model

---
 src/rlhf/ppo.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index f815eae..ef8fab1 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -323,10 +323,8 @@ class RLHF(pl.LightningModule):
     def generate(
         self,
         max_seq_len,
-        *args,
         prompt,
-        num_samples = 4,  # sample 4 per prompt and select the one with highest reward
-        **kwargs
+        num_samples = 4   # sample 4 per prompt and select the one with highest reward
     ):
         ''' Not used in training; inference only.
         '''
@@ -344,10 +342,8 @@ class RLHF(pl.LightningModule):
             _
         ) = self.actor_critic.generate(
             prompt,
-            *args,
             max_seq_len = max_seq_len,
-            return_values = False,
-            **kwargs
+            return_values = False
         )

         rewards = self.reward_model(
@@ -468,7 +464,6 @@ class RLHF(pl.LightningModule):
             rearrange(state, 'n -> 1 n'),
             max_seq_len = self.args.ctx_len,
             eos_token = eos_token,
-            temperature = temperature,
             return_values = True
         )
         action_logits = shift(action_logits, shift = 1, dim = -2)  # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
--
GitLab
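
Editor's note: the first two hunks simplify generate(), which performs best-of-n sampling at inference time: it draws num_samples completions for a prompt, scores each with the reward model, and keeps the highest-scoring one. Below is a minimal sketch of that selection step, assuming hypothetical generate_fn and reward_fn callables standing in for actor_critic.generate and the reward model (not the repo's actual signatures):

    import torch

    def best_of_n(generate_fn, reward_fn, prompt, num_samples=4, max_seq_len=1024):
        # Draw num_samples candidate completions for one prompt.
        candidates = [generate_fn(prompt, max_seq_len=max_seq_len)
                      for _ in range(num_samples)]
        # Score each candidate with the reward model
        # (assumed here to return a scalar tensor per candidate).
        rewards = torch.stack([reward_fn(prompt, c) for c in candidates])
        # Keep the completion the reward model rates highest.
        return candidates[int(rewards.argmax())]

Dropping *args/**kwargs trades forwarding flexibility for an explicit signature: callers can no longer pass extra sampling options through generate() to actor_critic.generate.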
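
The third hunk removes the temperature argument from another actor_critic.generate call; the surviving trailing comment explains the shift that follows. Logits at position i are the distribution over token i+1, so they are shifted forward by one along the sequence axis before being aligned with the sampled actions, which begin at the last prompt (state) token. An illustrative reimplementation of such a shift (not the repo's shift helper), assuming a (batch, seq, vocab) tensor:

    import torch

    def shift_seq(t: torch.Tensor, shift: int = 1, dim: int = -2) -> torch.Tensor:
        # Pad `shift` zeros at the front of the sequence dimension and drop
        # the same number from the end, so position i now holds the logits
        # that were produced while predicting token i.
        pad = torch.zeros_like(t.narrow(dim, 0, shift))
        return torch.cat([pad, t.narrow(dim, 0, t.size(dim) - shift)], dim=dim)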