diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index f815eaec130935cfffb965c5cd3d11653b65a495..ef8fab17fd67d663dd586aba9ed4aea6ac8b44b2 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -323,10 +323,8 @@ class RLHF(pl.LightningModule):
     def generate(
         self,
         max_seq_len,
-        *args,
         prompt,
-        num_samples = 4, # sample 4 per prompt and select the one with highest reward
-        **kwargs
+        num_samples = 4 # sample 4 per prompt and select the one with highest reward
     ):
         ''' Not used in training; inference only '''

@@ -344,10 +342,8 @@ class RLHF(pl.LightningModule):
             _
         ) = self.actor_critic.generate(
             prompt,
-            *args,
             max_seq_len = max_seq_len,
-            return_values = False,
-            **kwargs
+            return_values = False
         )

         rewards = self.reward_model(
@@ -468,7 +464,6 @@ class RLHF(pl.LightningModule):
                 rearrange(state, 'n -> 1 n'),
                 max_seq_len = self.args.ctx_len,
                 eos_token = eos_token,
-                temperature = temperature,
                 return_values = True
             )
             action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
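For reference, the behaviour the `num_samples` comment describes is plain best-of-n sampling: draw several completions for one prompt and keep the candidate the reward model scores highest. A minimal sketch of that selection step, assuming hypothetical `generate_one(prompt)` and `score(sequence)` callables standing in for `actor_critic.generate` and `reward_model` (names here are illustrative, not the repo's API):

    def best_of_n(prompt, generate_one, score, num_samples = 4):
        # Draw num_samples candidate completions for the same prompt.
        candidates = [generate_one(prompt) for _ in range(num_samples)]
        # Score every candidate with the reward model and keep the highest-scoring one.
        rewards = [float(score(c)) for c in candidates]
        best = max(range(num_samples), key = lambda i: rewards[i])
        return candidates[best], rewards[best]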