diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py index ae2ef5fa0b65c4a26a5244c0281b3e7336a1ece0..8652895efbbfa0c0fe13cabc31f85dc7fc7c7471 100644 --- a/src/rlhf/ppo.py +++ b/src/rlhf/ppo.py @@ -45,6 +45,7 @@ class ActorCritic(nn.Module): def __init__( self, rwkv: RWKV, + args, critic: Optional[RWKV] = None, pooled_values = False ): @@ -58,7 +59,7 @@ class ActorCritic(nn.Module): self.pooled_values = pooled_values self.value_head = nn.Sequential( - nn.Linear(rwkv.dim, 1), + nn.Linear(args.n_embd, 1), Rearrange('... 1 -> ...') ) @@ -284,6 +285,7 @@ class RLHF(nn.Module): # 使用 RWKV 初始化 actor_critic actor_critic = ActorCritic( rwkv = self.rwkv, + args = self.args, pooled_values = args.critic_pooled_values ).to(self.rwkv.device)