From 835e3c4c488d6d4af32853dd91e30f5e2c281c64 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Thu, 9 Sep 2021 19:27:04 +0800 Subject: [PATCH] fix(nyz): fix qmix double_q hidden state bug --- ding/model/template/qmix.py | 4 ++-- ding/policy/qmix.py | 1 + .../config/cooperative_navigation_qmix_config.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ding/model/template/qmix.py b/ding/model/template/qmix.py index 67b62a2f..50514efb 100644 --- a/ding/model/template/qmix.py +++ b/ding/model/template/qmix.py @@ -121,7 +121,7 @@ class QMix(nn.Module): self.mixer = mixer if self.mixer: self._mixer = Mixer(agent_num, global_obs_shape, embedding_size) - self._global_state_encoder = nn.Sequential() + self._global_state_encoder = nn.Identity() def forward(self, data: dict, single_step: bool = True) -> dict: """ @@ -428,7 +428,7 @@ class CollaQ(nn.Module): embedding_size = hidden_size_list[-1] if self.mixer: self._mixer = Mixer(agent_num, global_obs_shape, embedding_size) - self._global_state_encoder = nn.Sequential() + self._global_state_encoder = nn.Identity() def forward(self, data: dict, single_step: bool = True) -> dict: """ diff --git a/ding/policy/qmix.py b/ding/policy/qmix.py index 5d5ac822..263695f7 100644 --- a/ding/policy/qmix.py +++ b/ding/policy/qmix.py @@ -197,6 +197,7 @@ class QMIXPolicy(Policy): if self._cfg.learn.double_q: next_inputs = {'obs': data['next_obs']} + self._learn_model.reset(state=data['prev_state'][1]) logit_detach = self._learn_model.forward(next_inputs, single_step=False)['logit'].clone().detach() next_inputs = {'obs': data['next_obs'], 'action': logit_detach.argmax(dim=-1)} else: diff --git a/dizoo/multiagent_particle/config/cooperative_navigation_qmix_config.py b/dizoo/multiagent_particle/config/cooperative_navigation_qmix_config.py index 47158507..f41bad9f 100644 --- a/dizoo/multiagent_particle/config/cooperative_navigation_qmix_config.py +++ b/dizoo/multiagent_particle/config/cooperative_navigation_qmix_config.py @@ -32,6 +32,7 @@ main_config = dict( learning_rate=0.0005, target_update_theta=0.001, discount_factor=0.99, + double_q=True, ), collect=dict( n_sample=600, -- GitLab