Commit 8aabd2fc authored by u010280923

opt ppo model

Parent 9463b004
@@ -323,10 +323,8 @@ class RLHF(pl.LightningModule):
     def generate(
         self,
         max_seq_len,
-        *args,
         prompt,
-        num_samples = 4, # sample 4 per prompt and select the one with highest reward
-        **kwargs
+        num_samples = 4 # sample 4 per prompt and select the one with highest reward
     ):
         ''' Not used in training; inference only
         '''
@@ -344,10 +342,8 @@ class RLHF(pl.LightningModule):
             _
         ) = self.actor_critic.generate(
             prompt,
-            *args,
             max_seq_len = max_seq_len,
-            return_values = False,
-            **kwargs
+            return_values = False
         )
         rewards = self.reward_model(
@@ -468,7 +464,6 @@ class RLHF(pl.LightningModule):
                 rearrange(state, 'n -> 1 n'),
                 max_seq_len = self.args.ctx_len,
                 eos_token = eos_token,
-                temperature = temperature,
                 return_values = True
             )
             action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
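Net effect of the commit: RLHF.generate no longer forwards arbitrary *args/**kwargs (nor a temperature argument) to actor_critic.generate; callers now pass prompt, max_seq_len, and num_samples explicitly. Below is a minimal, runnable sketch of the resulting sampling loop. The stub classes, tensor shapes, and vocabulary size are hypothetical; only the keyword arguments (max_seq_len, return_values) and the sample-then-rank-by-reward logic come from the diff itself.

# Hypothetical sketch: ActorCriticStub / RewardStub stand in for the real
# actor_critic and reward_model; only the call shape mirrors the diff.
import torch

class ActorCriticStub:
    def generate(self, prompt, max_seq_len, return_values = False):
        # Sample random continuation tokens up to max_seq_len (stand-in for
        # autoregressive decoding; vocabulary size 100 is arbitrary).
        gen = torch.randint(0, 100, (max_seq_len - prompt.shape[-1],))
        return torch.cat((prompt, gen))  # the real model also returns actions, masks, logits, ...

class RewardStub:
    def __call__(self, sequence):
        # Scalar reward for a full sequence (random stand-in).
        return torch.randn(()).item()

actor_critic, reward_model = ActorCriticStub(), RewardStub()
prompt = torch.randint(0, 100, (8,))

# num_samples = 4: sample 4 completions per prompt and select the one with
# the highest reward, matching the comment retained in the diff.
candidates = [
    actor_critic.generate(prompt, max_seq_len = 16, return_values = False)
    for _ in range(4)
]
best = max(candidates, key = reward_model)

Dropping the *args/**kwargs passthrough makes the contract of this inference-only entry point explicit, at the cost of no longer forwarding sampling knobs such as temperature.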