QMIX
QMIXPolicy
- class ding.policy.qmix.QMIXPolicy(cfg: dict, model: Optional[Union[type, torch.nn.modules.module.Module]] = None, enable_field: Optional[List[str]] = None)
- Overview:
  Policy class of the QMIX algorithm. QMIX is a multi-agent reinforcement learning algorithm; you can read the paper at https://arxiv.org/abs/1803.11485
- Interface:
  _init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn, _init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval, _reset_eval, _get_train_sample, default_model
- Config:

  | ID | Symbol | Type | Default Value | Description | Other (Shape) |
  |----|--------|------|---------------|-------------|---------------|
  | 1 | type | str | qmix | RL policy register name, refer to registry POLICY_REGISTRY | this arg is optional, a placeholder |
  | 2 | cuda | bool | True | Whether to use cuda for network | this arg can be different from modes |
  | 3 | on_policy | bool | False | Whether the RL algorithm is on-policy or off-policy | |
  | 4 | priority | bool | False | Whether to use priority (PER) | priority sample, update priority |
  | 5 | priority_IS_weight | bool | False | Whether to use Importance Sampling weight to correct the biased update | IS weight |
  | 6 | learn.update_per_collect | int | 20 | How many updates (iterations) to train after one collection by the collector. Only valid in serial training | this arg can vary across envs; a bigger value means more off-policy |
  | 7 | learn.target_update_theta | float | 0.001 | Target network update momentum parameter | between [0, 1] |
  | 8 | learn.discount_factor | float | 0.99 | Reward's future discount factor, a.k.a. gamma | may be 1 in sparse reward envs |
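For reference, the defaults in this table can be written as a config dict. A minimal sketch, assuming the usual nesting of the learn.* fields shown above; the surrounding entry-script structure is omitted:

```python
# Sketch of a QMIX policy config mirroring the table defaults above.
qmix_policy_config = dict(
    type='qmix',
    cuda=True,
    on_policy=False,
    priority=False,
    priority_IS_weight=False,
    learn=dict(
        update_per_collect=20,
        target_update_theta=0.001,
        discount_factor=0.99,
    ),
)
```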
- _data_preprocess_learn(data: List[Any]) → dict
- Overview:
  Preprocess the data to fit the required data format for learning.
- Arguments:
  - data (List[Dict[str, Any]]): The data collected from the collect function.
- Returns:
  - data (Dict[str, Any]): The processed data, from [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
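As a concrete illustration of this reshaping, here is a hedged sketch; the helper name is hypothetical, and it omits details the real preprocessing also handles, such as device placement and the weight/IS fields:

```python
import torch

def stack_timestep_batch(data):
    # data: list (len=B) of trajectory dicts, each value a list (len=T) of tensors.
    # Returns one dict of stacked tensors with shape [T, B, *any_dims].
    return {
        k: torch.stack(
            [torch.stack([traj[k][t] for traj in data]) for t in range(len(data[0][k]))]
        )
        for k in data[0].keys()
    }

B, T = 4, 8
data = [{'obs': [torch.randn(3) for _ in range(T)]} for _ in range(B)]
out = stack_timestep_batch(data)
print(out['obs'].shape)  # torch.Size([8, 4, 3]), i.e. [T, B, any_dims]
```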
- _forward_collect(data: dict, eps: float) → dict
- Overview:
  Forward function for collect mode with eps_greedy.
- Arguments:
  - data (Dict[str, Any]): Dict type data, stacked env data for predicting policy_output (action); values are torch.Tensor, np.ndarray, or dict/list combinations, and keys are env ids indicated by integers.
  - eps (float): Epsilon value for exploration, which is decayed by collected env step.
- Returns:
  - output (Dict[int, Any]): Dict type data, including at least the inferred action according to input obs.
- ReturnsKeys:
  - necessary: action
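A hypothetical usage sketch; the policy instance, the two-agent observation shapes, and the 'agent_state'/'global_state' keys are assumptions rather than something this doc specifies:

```python
import torch

# Assumed obs layout for two envs with 2 agents and a 10-dim per-agent obs;
# the real keys and shapes depend on the environment wrapper.
obs = {
    0: {'agent_state': torch.randn(2, 10), 'global_state': torch.randn(20)},
    1: {'agent_state': torch.randn(2, 10), 'global_state': torch.randn(20)},
}
output = policy._forward_collect(obs, eps=0.1)  # epsilon-greedy exploration
actions = {env_id: o['action'] for env_id, o in output.items()}
```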
- _forward_eval(data: dict) → dict
- Overview:
  Forward function of eval mode, similar to self._forward_collect.
- Arguments:
  - data (Dict[str, Any]): Dict type data, stacked env data for predicting policy_output (action); values are torch.Tensor, np.ndarray, or dict/list combinations, and keys are env ids indicated by integers.
- Returns:
  - output (Dict[int, Any]): The dict of predicted actions for the interaction with env.
- ReturnsKeys:
  - necessary: action
- _forward_learn(data: dict) → Dict[str, Any]
- Overview:
  Forward and backward function of learn mode.
- Arguments:
  - data (Dict[str, Any]): Dict type data, a batch of data for training; values are torch.Tensor, np.ndarray, or dict/list combinations.
- Returns:
  - info_dict (Dict[str, Any]): Dict type data, an info dict indicating the training result, which will be recorded in the text log and tensorboard; values are python scalars or lists of scalars.
- ArgumentsKeys:
  - necessary: obs, next_obs, action, reward, weight, prev_state, done
- ReturnsKeys:
  - necessary: cur_lr, total_loss
    - cur_lr (float): Current learning rate
    - total_loss (float): The calculated loss
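A brief consumption sketch; train_data is a placeholder batch that must provide the ArgumentsKeys above:

```python
# train_data must contain: obs, next_obs, action, reward, weight, prev_state, done.
info = policy._forward_learn(train_data)
# cur_lr and total_loss are the necessary ReturnsKeys, both python scalars.
print(f"cur_lr={info['cur_lr']:.6f}, total_loss={info['total_loss']:.4f}")
```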
- _get_train_sample(data: list) → Union[None, List[Any]]
- Overview:
  Get the train sample from a trajectory.
- Arguments:
  - data (list): The trajectory's cache
- Returns:
  - samples (dict): The generated training samples
- _init_collect() → None
- Overview:
  Collect mode init method. Called by self.__init__. Init traj and unroll length and the collect model; enable the eps_greedy_sample and the hidden_state plugin.
- _init_eval() → None
- Overview:
  Evaluate mode init method. Called by self.__init__. Init the eval model with the argmax strategy and the hidden_state plugin.
- _init_learn() → None
- Overview:
  Learn mode init method. Called by self.__init__. Init the learner model of QMIXPolicy.
- Arguments:

  Note
  The _init_learn method takes its arguments from self._cfg.learn in the config file.

  - learning_rate (float): The learning rate for the optimizer
  - gamma (float): The discount factor
  - agent_num (int): Since this is a multi-agent algorithm, the agent num is required as input.
  - batch_size (int): Batch size info is needed to init the hidden_state plugins
- _load_state_dict_learn(state_dict: Dict[str, Any]) → None
- Overview:
  Load the state_dict variable into policy learn mode.
- Arguments:
  - state_dict (Dict[str, Any]): The dict of the policy learn state saved before.

  Tip
  If you want to load only some parts of the model, you can simply set the strict argument in load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
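A minimal sketch of the partial-loading tip; the checkpoint path, the 'model' key inside the saved dict, and the policy._model attribute are assumptions:

```python
import torch

state_dict = torch.load('ckpt.pth.tar', map_location='cpu')  # assumed checkpoint path
# strict=False silently skips parameters whose names do not match the current model.
policy._model.load_state_dict(state_dict['model'], strict=False)  # assumed 'model' key
```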
- _process_transition(obs: Any, model_output: dict, timestep: collections.namedtuple) → dict
- Overview:
  Generate dict type transition data from inputs.
- Arguments:
  - obs (Any): Env observation
  - model_output (dict): Output of the collect model, including at least ['action', 'prev_state']
  - timestep (namedtuple): Output after env step, including at least ['obs', 'reward', 'done'] (here 'obs' indicates the obs after env step).
- Returns:
  - transition (dict): Dict type transition data, including 'obs', 'next_obs', 'prev_state', 'action', 'reward', 'done'
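Putting the listed keys together, the assembled transition can be pictured as follows (a sketch mirroring the documented inputs and returned keys, not the verbatim implementation):

```python
# Sketch of the dict this method assembles from its three inputs:
transition = {
    'obs': obs,                                # observation before the env step
    'next_obs': timestep.obs,                  # observation after the env step
    'prev_state': model_output['prev_state'],  # hidden state carried by the collect model
    'action': model_output['action'],
    'reward': timestep.reward,
    'done': timestep.done,
}
```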
- _reset_collect(data_id: Optional[List[int]] = None) → None
- Overview:
  Reset the collect model to the states indicated by data_id.
- Arguments:
  - data_id (Optional[List[int]]): The ids of the stored states; the model state will be reset to the states indicated by data_id.
- _reset_eval(data_id: Optional[List[int]] = None) → None
- Overview:
  Reset the eval model to the states indicated by data_id.
- Arguments:
  - data_id (Optional[List[int]]): The ids of the stored states; the model state will be reset to the states indicated by data_id.
- _reset_learn(data_id: Optional[List[int]] = None) → None
- Overview:
  Reset the learn model to the states indicated by data_id.
- Arguments:
  - data_id (Optional[List[int]]): The ids of the stored states; the model state will be reset to the states indicated by data_id.
- _state_dict_learn() → Dict[str, Any]
- Overview:
  Return the state_dict of learn mode, usually including model and optimizer.
- Returns:
  - state_dict (Dict[str, Any]): The dict of the current policy learn state, for saving and restoring.
- default_model() → Tuple[str, List[str]]
- Overview:
  Return this algorithm's default model setting for demonstration.
- Returns:
  - model_info (Tuple[str, List[str]]): model name and model import_names

  Note
  The user can define and use a customized network model, but it must obey the same interface definition indicated by the import_names path. For QMIX, that is ding.model.qmix.qmix.
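For example, a hedged usage sketch; the exact returned strings are inferred from the note above, not guaranteed:

```python
model_name, import_names = policy.default_model()
print(model_name, import_names)  # expected along the lines of: qmix ['ding.model.qmix.qmix']
```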