From c13e797dadfe888204ffaef76ac11aef95becd35 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 17 Mar 2023 18:03:51 +0800 Subject: [PATCH] bug fixed --- src/rlhf/ppo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py index ae2ef5f..8652895 100644 --- a/src/rlhf/ppo.py +++ b/src/rlhf/ppo.py @@ -45,6 +45,7 @@ class ActorCritic(nn.Module): def __init__( self, rwkv: RWKV, + args, critic: Optional[RWKV] = None, pooled_values = False ): @@ -58,7 +59,7 @@ class ActorCritic(nn.Module): self.pooled_values = pooled_values self.value_head = nn.Sequential( - nn.Linear(rwkv.dim, 1), + nn.Linear(args.n_embd, 1), Rearrange('... 1 -> ...') ) @@ -284,6 +285,7 @@ class RLHF(nn.Module): # 使用 RWKV 初始化 actor_critic actor_critic = ActorCritic( rwkv = self.rwkv, + args = self.args, pooled_values = args.critic_pooled_values ).to(self.rwkv.device) -- GitLab