diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py index 0e1dc65646e94da194a32350a20cafd03c8433b2..f815eaec130935cfffb965c5cd3d11653b65a495 100644 --- a/src/rlhf/ppo.py +++ b/src/rlhf/ppo.py @@ -323,8 +323,10 @@ class RLHF(pl.LightningModule): def generate( self, max_seq_len, + *args, prompt, - num_samples = 4 # sample 4 per prompt and select the one with highest reward + num_samples = 4, # sample 4 per prompt and select the one with highest reward + **kwargs ): ''' 未参与训练,仅推理时使用 ''' @@ -342,8 +344,10 @@ class RLHF(pl.LightningModule): _ ) = self.actor_critic.generate( prompt, + *args, max_seq_len = max_seq_len, - return_values = False + return_values = False, + **kwargs ) rewards = self.reward_model(