diff --git a/ding/model/template/qmix.py b/ding/model/template/qmix.py index 67b62a2f7f6b499d693651f6bd4ffa100a5e5383..50514efb3304b609668d92771d1609f5fd962283 100644 --- a/ding/model/template/qmix.py +++ b/ding/model/template/qmix.py @@ -121,7 +121,7 @@ class QMix(nn.Module): self.mixer = mixer if self.mixer: self._mixer = Mixer(agent_num, global_obs_shape, embedding_size) - self._global_state_encoder = nn.Sequential() + self._global_state_encoder = nn.Identity() def forward(self, data: dict, single_step: bool = True) -> dict: """ @@ -428,7 +428,7 @@ class CollaQ(nn.Module): embedding_size = hidden_size_list[-1] if self.mixer: self._mixer = Mixer(agent_num, global_obs_shape, embedding_size) - self._global_state_encoder = nn.Sequential() + self._global_state_encoder = nn.Identity() def forward(self, data: dict, single_step: bool = True) -> dict: """ diff --git a/ding/policy/qmix.py b/ding/policy/qmix.py index 5d5ac822d44fa98c89b08f3afef10ae0e03d8922..263695f7f3b9eef80add03f31b04bbc93144924a 100644 --- a/ding/policy/qmix.py +++ b/ding/policy/qmix.py @@ -197,6 +197,7 @@ class QMIXPolicy(Policy): if self._cfg.learn.double_q: next_inputs = {'obs': data['next_obs']} + self._learn_model.reset(state=data['prev_state'][1]) logit_detach = self._learn_model.forward(next_inputs, single_step=False)['logit'].clone().detach() next_inputs = {'obs': data['next_obs'], 'action': logit_detach.argmax(dim=-1)} else: diff --git a/dizoo/multiagent_particle/config/cooperative_navigation_qmix_config.py b/dizoo/multiagent_particle/config/cooperative_navigation_qmix_config.py index 47158507b9768262c934ded0d36d6eb169703c8f..f41bad9fe8f471b47c9bb1f3bf160a3c1594af30 100644 --- a/dizoo/multiagent_particle/config/cooperative_navigation_qmix_config.py +++ b/dizoo/multiagent_particle/config/cooperative_navigation_qmix_config.py @@ -32,6 +32,7 @@ main_config = dict( learning_rate=0.0005, target_update_theta=0.001, discount_factor=0.99, + double_q=True, ), collect=dict( n_sample=600,