提交 835e3c4c 编写于 作者: N niuyazhe

fix(nyz): fix qmix double_q hidden state bug

上级 e22e5e43
......@@ -121,7 +121,7 @@ class QMix(nn.Module):
self.mixer = mixer
if self.mixer:
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size)
self._global_state_encoder = nn.Sequential()
self._global_state_encoder = nn.Identity()
def forward(self, data: dict, single_step: bool = True) -> dict:
"""
......@@ -428,7 +428,7 @@ class CollaQ(nn.Module):
embedding_size = hidden_size_list[-1]
if self.mixer:
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size)
self._global_state_encoder = nn.Sequential()
self._global_state_encoder = nn.Identity()
def forward(self, data: dict, single_step: bool = True) -> dict:
"""
......
......@@ -197,6 +197,7 @@ class QMIXPolicy(Policy):
if self._cfg.learn.double_q:
next_inputs = {'obs': data['next_obs']}
self._learn_model.reset(state=data['prev_state'][1])
logit_detach = self._learn_model.forward(next_inputs, single_step=False)['logit'].clone().detach()
next_inputs = {'obs': data['next_obs'], 'action': logit_detach.argmax(dim=-1)}
else:
......
......@@ -32,6 +32,7 @@ main_config = dict(
learning_rate=0.0005,
target_update_theta=0.001,
discount_factor=0.99,
double_q=True,
),
collect=dict(
n_sample=600,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册