提交 e2794fcb 编写于 作者: N niuyazhe

hotfix(nyz): fix qtran unittest import bug and qtran hidden size list bug

上级 1797568a
......@@ -47,8 +47,7 @@ class QTran(nn.Module):
super(QTran, self).__init__()
self._act = nn.ReLU()
self._q_network = DRQN(obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling)
assert len(hidden_size_list) == 1
q_input_size = global_obs_shape + hidden_size_list[0] + action_shape
q_input_size = global_obs_shape + hidden_size_list[-1] + action_shape
self.Q = nn.Sequential(
nn.Linear(q_input_size, embedding_size), nn.ReLU(), nn.Linear(embedding_size, embedding_size), nn.ReLU(),
nn.Linear(embedding_size, 1)
......@@ -59,7 +58,7 @@ class QTran(nn.Module):
nn.Linear(global_obs_shape, embedding_size), nn.ReLU(), nn.Linear(embedding_size, embedding_size),
nn.ReLU(), nn.Linear(embedding_size, 1)
)
ae_input = hidden_size_list[0] + action_shape
ae_input = hidden_size_list[-1] + action_shape
self.action_encoding = nn.Sequential(nn.Linear(ae_input, ae_input), nn.ReLU(), nn.Linear(ae_input, ae_input))
def forward(self, data: dict, single_step: bool = True) -> dict:
......
......@@ -3,3 +3,4 @@ from .cooperative_navigation_vdn_config import cooperative_navigation_vdn_config
from .cooperative_navigation_coma_config import cooperative_navigation_coma_config, cooperative_navigation_coma_create_config
from .cooperative_navigation_collaq_config import cooperative_navigation_collaq_config, cooperative_navigation_collaq_create_config
from .cooperative_navigation_atoc_config import cooperative_navigation_atoc_config, cooperative_navigation_atoc_create_config
from .cooperative_navigation_qtran_config import cooperative_navigation_qtran_config, cooperative_navigation_qtran_create_config
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册