template.coma

Please refer to ding/model/template/coma.py for usage.

COMAActorNetwork

class ding.model.template.coma.COMAActorNetwork(obs_shape: int, action_shape: int, hidden_size_list: ding.utils.type_helper.SequenceType = [128, 128, 64])
Overview:

Decentralized actor network in COMA

Interface:

__init__, forward

__init__(obs_shape: int, action_shape: int, hidden_size_list: ding.utils.type_helper.SequenceType = [128, 128, 64])
Overview:

Initialize the COMA actor network.

Arguments:
  • obs_shape (int): the dimension of each agent’s observation state

  • action_shape (int): the dimension of action shape

  • hidden_size_list (list): the list of hidden layer sizes, defaults to [128, 128, 64]

forward(inputs: Dict) -> Dict
ArgumentsKeys:
  • necessary: obs { agent_state, action_mask }, prev_state

ReturnsKeys:
  • necessary: logit, next_state, action_mask
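The key listing above does not say how tensor shapes move through forward. The following is a minimal NumPy sketch, assuming agent_state is laid out as (T, B, A, N) (time, batch, agent, feature) or (B, A, N) for a single step, and that one recurrent network is shared by all agents: the agent axis is folded into the batch axis, the shared network runs once, and the result is unfolded per agent. `shared_step`, `ACTION_DIM` and the (B, A, H) layout of prev_state are hypothetical simplifications for illustration, not DI-engine code; the real recurrent-state container may differ.

```python
import numpy as np

ACTION_DIM = 5  # assumed number of discrete actions (hypothetical)


def shared_step(x, prev_state):
    """Hypothetical stand-in for one recurrent core shared by all agents.

    x: (T, B*A, N) observations; prev_state: (B*A, H) hidden states.
    Returns logits of shape (T, B*A, ACTION_DIM) and a new hidden state.
    """
    w = np.full((x.shape[-1], ACTION_DIM), 1.0 / x.shape[-1])  # fixed toy weights
    logit = x @ w
    next_state = prev_state + 1.0  # dummy update, shape unchanged
    return logit, next_state


def actor_forward(inputs):
    agent_state = inputs['obs']['agent_state']  # (T, B, A, N) or (B, A, N)
    prev_state = inputs['prev_state']           # (B, A, H), simplified layout
    single_step = agent_state.ndim == 3
    if single_step:                              # add a time axis of length 1
        agent_state = agent_state[None]
    T, B, A, N = agent_state.shape
    # Fold the agent axis into the batch axis so all agents share weights.
    logit, next_h = shared_step(agent_state.reshape(T, B * A, N), prev_state.reshape(B * A, -1))
    logit = logit.reshape(T, B, A, -1)           # unfold back to per-agent layout
    if single_step:
        logit = logit[0]
    return {
        'logit': logit,
        'next_state': next_h.reshape(B, A, -1),
        'action_mask': inputs['obs']['action_mask'],  # passed through unchanged
    }
```

Passing action_mask through untouched matches the ReturnsKeys listing: the actor only produces logits, and masking of unavailable actions is left to whoever samples from them.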

COMACriticNetwork

class ding.model.template.coma.COMACriticNetwork(input_size: int, action_shape: int, hidden_size: int = 128)
Overview:

Centralized critic network in COMA

Interface:

__init__, forward

__init__(input_size: int, action_shape: int, hidden_size: int = 128)
Overview:

Initialize the COMA critic network.

Arguments:
  • input_size (int): the size of the critic input, i.e. the vector built from the global observation by _preprocess_data

  • action_shape (int): the dimension of the (discrete) action space

  • hidden_size (int): the size of the hidden layers, defaults to 128
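With these three arguments the critic maps one input vector per agent to one Q-value per action. A minimal NumPy sketch, assuming a two-hidden-layer ReLU MLP with small random weights; the layer count, activation and initialisation are assumptions for illustration, not read from coma.py, and `CriticSketch` is a hypothetical name.

```python
import numpy as np


class CriticSketch:
    """Illustrative MLP critic: input_size -> hidden_size -> hidden_size -> action_shape."""

    def __init__(self, input_size, action_shape, hidden_size=128, seed=0):
        rng = np.random.default_rng(seed)
        sizes = [input_size, hidden_size, hidden_size, action_shape]
        self.weights = [rng.normal(0.0, 0.1, (i, o)) for i, o in zip(sizes[:-1], sizes[1:])]
        self.biases = [np.zeros(o) for o in sizes[1:]]

    def __call__(self, x):
        # ReLU on the hidden layers; the output layer is linear (one Q-value per action).
        for w, b in zip(self.weights[:-1], self.biases[:-1]):
            x = np.maximum(x @ w + b, 0.0)
        return {'q_value': x @ self.weights[-1] + self.biases[-1]}
```

Because matrix multiplication broadcasts over leading axes, a (T, B, A, input_size) input yields a (T, B, A, action_shape) q_value without any reshaping.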

_preprocess_data(data: Dict) -> torch.Tensor
Overview:

Preprocess the input data into a tensor that can be fed to the MLP network.

Arguments:
  • data (dict): input data dict with keys [‘obs’, ‘prev_state’, ‘action’]

  • agent_state (torch.Tensor): each agent's local state (obs)

  • global_state (torch.Tensor): the global state (obs)

  • action (torch.Tensor): the masked action

ArgumentsKeys:
  • necessary: obs { agent_state, global_state }, action, prev_state

Return:
  • x (torch.Tensor): the input tensor for the MLP network, concatenated from global_state, agent_state, last_action, action and agent_id
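A minimal NumPy sketch of how the five listed pieces could be concatenated into one critic input per agent. It rests on several assumptions that are not confirmed by this page: agent_state has shape (T, B, A, D) and ends with the agent's own one-hot last action followed by a one-hot agent id; action holds integer actions of shape (T, B, A); every agent sees the joint last action of all agents; and "masked action" means each agent's own block is zeroed out of the joint one-hot action, so the critic only sees the other agents' current actions. `one_hot` and `preprocess_sketch` are hypothetical helpers.

```python
import numpy as np


def one_hot(idx, n):
    """One-hot encode an integer array along a new last axis."""
    return np.eye(n)[idx]


def preprocess_sketch(agent_state_full, global_state, action, action_dim):
    """Build one critic input vector per agent.

    agent_state_full: (T, B, A, D), assumed to end with [last_action one-hot | agent_id one-hot]
    global_state:     (T, B, G)
    action:           (T, B, A) integer actions
    """
    T, B, A, _ = agent_state_full.shape
    agent_state = agent_state_full[..., :-(action_dim + A)]        # pure local observation
    last_action = agent_state_full[..., -(action_dim + A):-A]      # (T, B, A, action_dim)
    agent_id = agent_state_full[..., -A:]                          # (T, B, A, A)
    # Every agent sees the joint last action of all agents.
    joint_last = np.repeat(last_action.reshape(T, B, 1, A * action_dim), A, axis=2)
    # Joint one-hot of the current actions, copied once per agent ...
    joint_action = np.repeat(one_hot(action, action_dim).reshape(T, B, 1, A * action_dim), A, axis=2)
    # ... with agent i's own block zeroed, leaving only the other agents' actions.
    mask = np.repeat(1.0 - np.eye(A), action_dim, axis=1)          # (A, A * action_dim)
    masked_action = joint_action * mask
    global_rep = np.repeat(global_state[:, :, None, :], A, axis=2)  # (T, B, A, G)
    return np.concatenate([global_rep, agent_state, joint_last, masked_action, agent_id], axis=-1)
```

Under these assumptions the output width is G + D + (2A - 1) * action_dim, since the agent's own last action is replaced by the joint last action.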

forward(data: Dict) -> Dict
Overview:

Forward computation graph of the COMA critic network.

Arguments:
  • data (dict): input data dict with keys [‘obs’, ‘prev_state’, ‘action’]

  • agent_state (torch.Tensor): each agent's local state (obs)

  • global_state (torch.Tensor): the global state (obs)

  • action (torch.Tensor): the masked action

ArgumentsKeys:
  • necessary: obs { agent_state, global_state }, action, prev_state

ReturnsKeys:
  • necessary: q_value
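The critic returns one Q-value per action of each agent, with the other agents' actions held fixed. In the COMA algorithm these values feed a counterfactual advantage: for agent a that took action u_a, advantage = Q(s, u_a) - sum over b of pi_a(b) * Q(s, b), i.e. the Q-value of the taken action minus its expectation under the agent's own policy. That computation is not part of the model class documented here; the NumPy sketch below is illustrative only, and `softmax` and `counterfactual_advantage` are hypothetical helpers that assume logit, q_value and action_mask share the shape (..., A, N).

```python
import numpy as np


def softmax(logit, mask):
    """Softmax over actions; unavailable actions (mask == 0) get zero probability."""
    logit = np.where(mask > 0, logit, -1e9)
    z = np.exp(logit - logit.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)


def counterfactual_advantage(q_value, logit, action_mask, action):
    """q_value, logit, action_mask: (..., A, N); action: (..., A) integer actions taken."""
    pi = softmax(logit, action_mask)
    baseline = (pi * q_value).sum(axis=-1)  # expected Q under the agent's own policy
    q_taken = np.take_along_axis(q_value, action[..., None], axis=-1)[..., 0]
    return q_taken - baseline
```

For example, with Q-values [1, 3] and a uniform policy the baseline is 2, so taking action 1 gives an advantage of 1.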

COMA

class ding.model.template.coma.COMA(agent_num: int, obs_shape: Dict, action_shape: Union[int, ding.utils.type_helper.SequenceType], actor_hidden_size_list: ding.utils.type_helper.SequenceType)
Overview:

The COMA network is a QAC-type (Q-value actor-critic) model.

__init__(agent_num: int, obs_shape: Dict, action_shape: Union[int, ding.utils.type_helper.SequenceType], actor_hidden_size_list: ding.utils.type_helper.SequenceType) -> None
Overview:

Initialize the COMA network.

Arguments:
  • agent_num (int): the number of agents

  • obs_shape (Dict): the observation information, including agent_state and global_state

  • action_shape (Union[int, SequenceType]): the dimension of the action space

  • actor_hidden_size_list (SequenceType): the list of hidden layer sizes for the actor network
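COMA receives only agent_num, obs_shape and action_shape, so the actor and critic input sizes have to be derived from them. A hedged sketch of that bookkeeping, assuming the actor consumes agent_state as-is, that agent_state already ends with the agent's own one-hot last action and one-hot agent id, and that the critic input is laid out as in the _preprocess_data description (global state, local observation, joint last action, masked joint action, agent id). `as_int` and `coma_input_sizes` are hypothetical helpers, not DI-engine functions.

```python
from typing import Dict, Sequence, Union


def as_int(shape: Union[int, Sequence[int]]) -> int:
    """Accept either an int or a one-element sequence such as [12]."""
    if isinstance(shape, int):
        return shape
    assert len(shape) == 1, "only flat (1-D) shapes are handled in this sketch"
    return int(shape[0])


def coma_input_sizes(agent_num: int, obs_shape: Dict, action_shape) -> Dict[str, int]:
    agent_dim = as_int(obs_shape['agent_state'])
    global_dim = as_int(obs_shape['global_state'])
    action_dim = as_int(action_shape)
    critic = (global_dim
              + (agent_dim - action_dim - agent_num)  # local observation without last action / id
              + agent_num * action_dim                # joint last action of all agents
              + agent_num * action_dim                # joint current action, own block masked
              + agent_num)                            # one-hot agent id
    return {'actor_input_size': agent_dim, 'critic_input_size': critic}
```

The sum simplifies to global_dim + agent_dim + (2 * agent_num - 1) * action_dim.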

forward(inputs: Dict, mode: str) -> Dict
ArgumentsKeys:
  • necessary: obs { agent_state, global_state, action_mask }, action, prev_state

ReturnsKeys:
  • necessary:
    • compute_critic: q_value

    • compute_actor: logit, next_state, action_mask
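The ReturnsKeys above imply that mode selects which sub-network runs. A minimal sketch of that dispatch, assuming the two mode names listed and using stub callables in place of COMAActorNetwork and COMACriticNetwork; `ModeDispatchSketch` is a hypothetical name, and raising ValueError for an unknown mode is this sketch's choice (the real model may report it differently).

```python
from typing import Callable, Dict


class ModeDispatchSketch:
    """Illustrative two-mode forward: one model object, two sub-networks."""

    mode = ['compute_actor', 'compute_critic']

    def __init__(self, actor: Callable[[Dict], Dict], critic: Callable[[Dict], Dict]) -> None:
        self.actor = actor
        self.critic = critic

    def forward(self, inputs: Dict, mode: str) -> Dict:
        if mode not in self.mode:
            raise ValueError(f"unsupported forward mode: {mode!r}, expected one of {self.mode}")
        # Route to the decentralized actor or the centralized critic.
        return self.actor(inputs) if mode == 'compute_actor' else self.critic(inputs)
```

Keeping both sub-networks behind one forward lets a policy hold a single model object and call the actor when acting and the critic when computing Q-values.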