提交 9463b004 编写于 作者: U u010280923

opt ppo model

上级 cf2bc522
......@@ -323,8 +323,10 @@ class RLHF(pl.LightningModule):
def generate(
self,
max_seq_len,
*args,
prompt,
num_samples = 4 # sample 4 per prompt and select the one with highest reward
num_samples = 4, # sample 4 per prompt and select the one with highest reward
**kwargs
):
''' 未参与训练,仅推理时使用
'''
......@@ -342,8 +344,10 @@ class RLHF(pl.LightningModule):
_
) = self.actor_critic.generate(
prompt,
*args,
max_seq_len = max_seq_len,
return_values = False
return_values = False,
**kwargs
)
rewards = self.reward_model(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册