CollaQ

CollaQPolicy

class ding.policy.collaq.CollaQPolicy(cfg: dict, model: Optional[Union[type, torch.nn.modules.module.Module]] = None, enable_field: Optional[List[str]] = None)[source]
Overview:

Policy class of the CollaQ algorithm. CollaQ (Multi-Agent Collaboration via Reward Attribution Decomposition) is a multi-agent reinforcement learning algorithm.

Interface:
_init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn

_init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval, _reset_eval, _get_train_sample, default_model

Config:

ID | Symbol                    | Type  | Default Value | Description                                                                                                  | Other (Shape)
---|---------------------------|-------|---------------|--------------------------------------------------------------------------------------------------------------|-----------------------------------------------
1  | type                      | str   | collaq        | RL policy register name, refer to registry POLICY_REGISTRY                                                    | this arg is optional, a placeholder
2  | cuda                      | bool  | True          | Whether to use cuda for the network                                                                           | this arg can differ between modes
3  | on_policy                 | bool  | False         | Whether the RL algorithm is on-policy or off-policy                                                           |
4  | priority                  | bool  | False         | Whether to use priority (PER)                                                                                 | priority sample, update priority
5  | priority_IS_weight        | bool  | False         | Whether to use Importance Sampling weight to correct the biased update                                        | IS weight
6  | learn.update_per_collect  | int   | 20            | How many updates (iterations) to train after one collection by the collector. Only valid in serial training   | this arg can vary across envs; a bigger value means more off-policy
7  | learn.target_update_theta | float | 0.001         | Target network update momentum parameter                                                                      | between [0, 1]
8  | learn.discount_factor     | float | 0.99          | Reward's future discount factor, aka. gamma                                                                   | may be 1 in sparse reward envs
9  | learn.collaq_loss_weight  | float | 1.0           | The weight of the CollaQ MARA loss                                                                            |
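For reference, the table above maps onto a nested policy config; a minimal sketch is shown below. Field nesting follows the dotted Symbol names (learn.* lives under the learn sub-dict), and the values are the listed defaults; this is an illustration, not a complete config.

    # Minimal config sketch assembled from the table above (illustrative, not exhaustive).
    collaq_config = dict(
        type='collaq',
        cuda=True,
        on_policy=False,
        priority=False,
        priority_IS_weight=False,
        learn=dict(
            update_per_collect=20,
            target_update_theta=0.001,
            discount_factor=0.99,
            collaq_loss_weight=1.0,
        ),
    )
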
_data_preprocess_learn(data: List[Any]) -> dict [source]
Overview:

Preprocess the data to fit the required data format for learning

Arguments:
  • data (List[Dict[str, Any]]): the data collected from collect function

Returns:
  • data (Dict[str, Any]): the processed data, from

    [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
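The layout change above can be pictured with a short, self-contained sketch. This only illustrates the documented shape transform; the actual method additionally handles details such as device placement and priority/IS weights.

    import torch
    from typing import Dict, List

    def stack_time_batch(data: List[Dict[str, List[torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
        out = {}
        for key in data[0].keys():
            per_sample = [torch.stack(sample[key], dim=0) for sample in data]  # B tensors of shape [T, any_dims]
            out[key] = torch.stack(per_sample, dim=1)                          # [T, B, any_dims]
        return out
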

_forward_collect(data: dict, eps: float) -> dict [source]
Overview:

Forward function for collect mode with eps_greedy

Arguments:
  • data (Dict[str, Any]): Dict type data, stacked env data for predicting policy_output (action);

    values are torch.Tensor, np.ndarray, or dict/list combinations; keys are integer env_ids.

  • eps (float): epsilon value for exploration, decayed according to the number of collected env steps.

Returns:
  • output (Dict[int, Any]): Dict type data, including at least inferred action according to input obs.

ReturnsKeys
  • necessary: action
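A hypothetical collect-loop fragment, assuming this method is reached through the policy's collect-mode forward and that policy is an already instantiated CollaQPolicy:

    # Hypothetical usage: obs are keyed by integer env_id; eps is the current exploration rate.
    eps = 0.1
    stacked_obs = {0: obs_from_env_0, 1: obs_from_env_1}   # placeholder observations
    policy_output = policy.collect_mode.forward(stacked_obs, eps=eps)
    actions = {env_id: out['action'] for env_id, out in policy_output.items()}
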

_forward_eval(data: dict) -> dict [source]
Overview:

Forward function for eval mode, similar to self._forward_collect.

Arguments:
  • data (Dict[str, Any]): Dict type data, stacked env data for predicting policy_output (action);

    values are torch.Tensor, np.ndarray, or dict/list combinations; keys are integer env_ids.

Returns:
  • output (Dict[int, Any]): The dict of predicting action for the interaction with env.

ReturnsKeys
  • necessary: action

_forward_learn(data: dict) -> Dict[str, Any] [source]
Overview:

Forward and backward function of learn mode.

Arguments:
  • data (Dict[str, Any]): Dict type data, a batch of data for training, values are torch.Tensor or

    np.ndarray or dict/list combinations.

Returns:
  • info_dict (Dict[str, Any]): Dict type data, an info dict indicating the training result, which will be

    recorded in the text log and tensorboard; values are python scalars or lists of scalars.

ArgumentsKeys:
  • necessary: obs, next_obs, action, reward, weight, prev_state, done

ReturnsKeys:
  • necessary: cur_lr, total_loss
    • cur_lr (float): Current learning rate

    • total_loss (float): The calculated loss
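A learn-step fragment sketching the data flow; all batch tensors are placeholders (shaped [T, B, ...] after _data_preprocess_learn), and the logger calls are only an assumption about how the returned scalars might be recorded:

    # Hypothetical learn step: feed a batch with the necessary keys and log the returned scalars.
    train_batch = dict(
        obs=obs, next_obs=next_obs, action=action, reward=reward,
        weight=None, prev_state=prev_state, done=done,
    )
    learn_info = policy.learn_mode.forward(train_batch)
    tb_logger.add_scalar('learner/total_loss', learn_info['total_loss'], train_iter)
    tb_logger.add_scalar('learner/cur_lr', learn_info['cur_lr'], train_iter)
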

_get_train_sample(data: list) -> Union[None, List[Any]] [source]
Overview:

Get training samples from the trajectory.

Arguments:
  • data (list): The trajectory’s cache

Returns:
  • samples (list): The training samples generated
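Conceptually this amounts to cutting the cached trajectory into fixed-length unroll segments; the sketch below is a simplification (the actual method may also pad or special-case the trailing remainder):

    from typing import Any, Dict, List

    def split_into_unroll_samples(traj: List[Dict[str, Any]], unroll_len: int) -> List[List[Dict[str, Any]]]:
        # Cut a list of transitions into consecutive segments of length unroll_len.
        return [traj[i:i + unroll_len] for i in range(0, len(traj), unroll_len)]
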

_init_collect() -> None [source]
Overview:

Collect mode init method. Called by self.__init__. Initialize the trajectory and unroll lengths and the collect model. Enable the eps_greedy_sample and hidden_state plugins.
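The plugin setup can be sketched with DI-engine's model_wrap helper; the wrapper names follow the overview above, while the keyword arguments (e.g. state_num, save_prev_state) and the variables model/env_num are assumptions rather than a verbatim copy of the implementation:

    from ding.model import model_wrap

    # Wrap the collect model: keep per-env recurrent states and sample actions epsilon-greedily.
    collect_model = model_wrap(model, wrapper_name='hidden_state', state_num=env_num, save_prev_state=True)
    collect_model = model_wrap(collect_model, wrapper_name='eps_greedy_sample')
    collect_model.reset()
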

_init_eval() -> None [source]
Overview:

Evaluate mode init method. Called by self.__init__. Init eval model with argmax strategy and the hidden_state plugin.

_init_learn() -> None [source]
Overview:

Learn mode init method. Called by self.__init__. Init the learner model of CollaQPolicy

Arguments:

Note

The _init_learn method takes the argument from the self._cfg.learn in the config file

  • learning_rate (float): The learning rate of the optimizer

  • gamma (float): The discount factor

  • alpha (float): The CollaQ loss factor, i.e. the weight of the MARA loss

  • agent_num (int): Since this is a multi-agent algorithm, the number of agents must be provided.

  • batch_size (int): Need batch size info to init hidden_state plugins

_load_state_dict_learn(state_dict: Dict[str, Any]) -> None [source]
Overview:

Load the state_dict variable into policy learn mode.

Arguments:
  • state_dict (Dict[str, Any]): the dict of policy learn state saved before.

Tip

If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
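Following the tip, a partial load might look like this; the checkpoint key name 'model' follows the usual model/optimizer convention mentioned under _state_dict_learn, and the _learn_model attribute is used here only as a placeholder for the policy's underlying learn model:

    import torch

    # Hypothetical partial load: keep only the parameters whose names and shapes match.
    ckpt = torch.load('collaq_ckpt.pth.tar', map_location='cpu')
    policy._learn_model.load_state_dict(ckpt['model'], strict=False)
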

_process_transition(obs: Any, model_output: dict, timestep: collections.namedtuple) -> dict [source]
Overview:

Generate dict type transition data from inputs.

Arguments:
  • obs (Any): Env observation

  • model_output (dict): Output of collect model, including at least

    [‘action’, ‘prev_state’, ‘agent_colla_alone_q’]

  • timestep (namedtuple): Output after env step, including at least [‘obs’, ‘reward’, ‘done’]

    (here ‘obs’ indicates obs after env step).

Returns:
  • transition (dict): Dict type transition data.
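A sketch of the transition assembled from these inputs; the key set is inferred from the documented model_output and timestep fields and is not a verbatim copy of the implementation:

    from typing import Any, Dict

    def sketch_process_transition(obs: Any, model_output: Dict[str, Any], timestep) -> Dict[str, Any]:
        # Pair the pre-step obs with the post-step timestep and the relevant model outputs.
        return {
            'obs': obs,
            'next_obs': timestep.obs,                                   # obs after env step
            'action': model_output['action'],
            'prev_state': model_output['prev_state'],
            'agent_colla_alone_q': model_output['agent_colla_alone_q'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
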

_reset_collect(data_id: Optional[List[int]] = None) -> None [source]
Overview:

Reset collect model to the state indicated by data_id

Arguments:
  • data_id (Optional[List[int]]): The ids indexing the stored model states; the model states indicated by data_id will be reset.

_reset_eval(data_id: Optional[List[int]] = None) -> None [source]
Overview:

Reset eval model to the state indicated by data_id

Arguments:
  • data_id (Optional[List[int]]): The ids indexing the stored model states; the model states indicated by data_id will be reset.

_reset_learn(data_id: Optional[List[int]] = None) -> None [source]
Overview:

Reset learn model to the state indicated by data_id

Arguments:
  • data_id (Optional[List[int]]): The ids indexing the stored model states; the model states indicated by data_id will be reset.

_state_dict_learn() -> Dict[str, Any] [source]
Overview:

Return the state_dict of learn mode, usually including model and optimizer.

Returns:
  • state_dict (Dict[str, Any]): the dict of current policy learn state, for saving and restoring.
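Checkpointing is then a matter of serializing this dict; the file name below is illustrative, and learn_mode.state_dict is assumed to dispatch to this method:

    import torch

    # Hypothetical checkpointing of the learn-mode state (model, optimizer, ...).
    torch.save(policy.learn_mode.state_dict(), './ckpt/collaq_iteration_10000.pth.tar')
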

default_model() -> Tuple[str, List[str]] [source]
Overview:

Return this algorithm's default model setting for demonstration.

Returns:
  • model_info (Tuple[str, List[str]]): model name and model import_names

Note

The user can define and use a customized network model, but it must obey the same interface definition indicated by the import_names path. For collaq, ding.model.qmix.qmix
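If a customized network is preferred, it can be passed to the policy at construction time (see the model argument in the class signature); the class below is purely a placeholder and must implement the same interface as the default model:

    import torch.nn as nn

    class MyCollaQNet(nn.Module):
        # Placeholder: implement the same forward interface as the default CollaQ model.
        def forward(self, data: dict) -> dict:
            raise NotImplementedError

    policy = CollaQPolicy(cfg, model=MyCollaQNet())   # cfg is a prepared policy config (placeholder)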