提交 c13e797d 编写于 作者: U u010280923

bug fixed

上级 3b7e41b2
......@@ -45,6 +45,7 @@ class ActorCritic(nn.Module):
def __init__(
self,
rwkv: RWKV,
args,
critic: Optional[RWKV] = None,
pooled_values = False
):
......@@ -58,7 +59,7 @@ class ActorCritic(nn.Module):
self.pooled_values = pooled_values
self.value_head = nn.Sequential(
nn.Linear(rwkv.dim, 1),
nn.Linear(args.n_embd, 1),
Rearrange('... 1 -> ...')
)
......@@ -284,6 +285,7 @@ class RLHF(nn.Module):
# 使用 RWKV 初始化 actor_critic
actor_critic = ActorCritic(
rwkv = self.rwkv,
args = self.args,
pooled_values = args.critic_pooled_values
).to(self.rwkv.device)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册