DQN

DQNPolicy

class ding.policy.dqn.DQNPolicy(cfg: dict, model: Optional[Union[type, torch.nn.modules.module.Module]] = None, enable_field: Optional[List[str]] = None)[source]
Overview:

Policy class of DQN algorithm, extended by Double DQN/Dueling DQN/PER/multi-step TD.

Config:

| ID | Symbol | Type | Default Value | Description | Other(Shape) |
| --- | --- | --- | --- | --- | --- |
| 1 | type | str | dqn | RL policy register name, refer to registry POLICY_REGISTRY | This arg is optional, a placeholder |
| 2 | cuda | bool | False | Whether to use cuda for network | This arg can be different from modes |
| 3 | on_policy | bool | False | Whether the RL algorithm is on-policy or off-policy | |
| 4 | priority | bool | False | Whether to use priority (PER) | Priority sample, update priority |
| 5 | priority_IS_weight | bool | False | Whether to use Importance Sampling weight to correct biased update. If True, priority must be True. | |
| 6 | discount_factor | float | 0.97, [0.95, 0.999] | Reward's future discount factor, aka. gamma | May be 1 in sparse reward envs |
| 7 | nstep | int | 1, [3, 5] | N-step reward discount sum for target q_value estimation | |
| 8 | learn.update_per_collect | int | 3 | How many updates (iterations) to train after one collection by the collector. Only valid in serial training | This arg can vary across envs; a bigger value means more off-policy |
| 9 | learn.multi_gpu | bool | False | Whether to use multi-GPU during training | |
| 10 | learn.batch_size | int | 64 | The number of samples of an iteration | |
| 11 | learn.learning_rate | float | 0.001 | Gradient step length of an iteration | |
| 12 | learn.target_update_freq | int | 100 | Frequency of target network update | Hard (assign) update |
| 13 | learn.ignore_done | bool | False | Whether to ignore done for target value calculation | Enable it for some fake-termination envs |
| 14 | collect.n_sample | int | [8, 128] | The number of training samples per call of the collector | It varies across envs |
| 15 | collect.unroll_len | int | 1 | Unroll length of an iteration | In RNN, unroll_len > 1 |
| 16 | other.eps.type | str | exp | Exploration rate decay type | Supports ['exp', 'linear'] |
| 17 | other.eps.start | float | 0.95 | Start value of exploration rate | [0, 1] |
| 18 | other.eps.end | float | 0.1 | End value of exploration rate | [0, 1] |
| 19 | other.eps.decay | int | 10000 | Decay length of exploration | Greater than 0. decay=10000 means the exploration rate decays from the start value to the end value over the decay length |
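
The snippet below is a minimal sketch of a policy config assembled from the keys above. The field values and the model shapes (obs_shape, action_shape) are illustrative, not taken from a specific experiment; in practice DI-engine entry points (e.g. ding.entry.serial_pipeline together with compile_config) merge such a user config with DQNPolicy.default_config() before constructing the policy.

```python
# Illustrative DQN policy config built from the documented keys.
# obs_shape/action_shape are hypothetical and must match your environment.
from easydict import EasyDict

dqn_policy_config = EasyDict(dict(
    type='dqn',
    cuda=False,
    on_policy=False,
    priority=False,
    priority_IS_weight=False,
    discount_factor=0.97,
    nstep=1,
    model=dict(obs_shape=4, action_shape=2),  # hypothetical CartPole-like sizes
    learn=dict(
        update_per_collect=3,
        batch_size=64,
        learning_rate=0.001,
        target_update_freq=100,
        ignore_done=False,
    ),
    collect=dict(n_sample=8, unroll_len=1),
    other=dict(eps=dict(type='exp', start=0.95, end=0.1, decay=10000)),
))
```

The method sketches below assume a variable `policy` bound to an already initialized DQNPolicy built from such a config; the construction itself (merging with DQNPolicy.default_config() and creating the model) is omitted.
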
_forward_collect(data: Dict[int, Any], eps: float) Dict[int, Any][source]
Overview:

Forward computation graph of collect mode (collecting training data), with eps_greedy for exploration.

Arguments:
  • data (Dict[int, Any]): Dict type data, stacked env data for predicting policy_output (action); values are torch.Tensor, np.ndarray, or dict/list combinations, and keys are env_ids indicated by integers.

  • eps (float): The epsilon value for exploration, which is decayed with the number of collected env steps.

Returns:
  • output (Dict[int, Any]): The dict of predicted policy_output (action) for interaction with the env and for constructing the transition.

ArgumentsKeys:
  • necessary: obs

ReturnsKeys
  • necessary: logit, action
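
A hedged usage sketch, assuming `policy` is the initialized DQNPolicy from the config sketch above with a 4-dimensional observation space; in a full pipeline this call is usually reached through policy.collect_mode.forward rather than invoked directly.

```python
import torch

# Keys are env_ids; values are single-env observations (hypothetical 4-dim obs).
obs = {0: torch.randn(4), 1: torch.randn(4)}
outputs = policy._forward_collect(obs, eps=0.5)
for env_id, out in outputs.items():
    # Each per-env output contains at least 'logit' and 'action' (see ReturnsKeys).
    print(env_id, out['action'].item(), out['logit'].shape)
```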

_forward_eval(data: Dict[int, Any]) Dict[int, Any][source]
Overview:

Forward computation graph of eval mode (evaluating policy performance); in most cases, it is similar to self._forward_collect.

Arguments:
  • data (Dict[int, Any]): Dict type data, stacked env data for predicting policy_output (action); values are torch.Tensor, np.ndarray, or dict/list combinations, and keys are env_ids indicated by integers.

Returns:
  • output (Dict[int, Any]): The dict of predicted action for interaction with the env.

ArgumentsKeys:
  • necessary: obs

ReturnsKeys
  • necessary: action
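
A corresponding eval-mode sketch under the same assumptions; the main difference from collect mode is that no eps is passed and actions are selected greedily.

```python
import torch

eval_obs = {0: torch.randn(4)}                # hypothetical 4-dim observation
eval_output = policy._forward_eval(eval_obs)  # usually reached via policy.eval_mode.forward
print(eval_output[0]['action'])
```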

_forward_learn(data: Dict[str, Any]) Dict[str, Any][source]
Overview:

Forward computation graph of learn mode (updating policy).

Arguments:
  • data (Dict[str, Any]): Dict type data, a batch of data for training, values are torch.Tensor or np.ndarray or dict/list combinations.

Returns:
  • info_dict (Dict[str, Any]): Dict type data, an info dict indicating the training result, which will be recorded in the text log and tensorboard; values are python scalars or lists of scalars.

ArgumentsKeys:
  • necessary: obs, action, reward, next_obs, done

  • optional: value_gamma, IS

ReturnsKeys:
  • necessary: cur_lr, total_loss, priority

  • optional: action_distribution
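
A hedged sketch of the expected batch layout, assuming nstep=1 and the hypothetical 4-dim observation / 2-action setup from the config sketch above; in a real run this batch is sampled from the replay buffer and the call is reached through policy.learn_mode.forward.

```python
import torch

batch_size, obs_dim, action_num = 64, 4, 2   # hypothetical sizes
train_data = {
    'obs': torch.randn(batch_size, obs_dim),
    'next_obs': torch.randn(batch_size, obs_dim),
    'action': torch.randint(0, action_num, (batch_size,)),
    'reward': torch.randn(batch_size),       # with nstep > 1, this carries the n-step reward sequence
    'done': torch.zeros(batch_size),
    # optional keys: 'value_gamma', 'IS'
}
info = policy._forward_learn(train_data)
print(info['cur_lr'], info['total_loss'])    # scalars recorded in logs/tensorboard
```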

_get_train_sample(data: List[Dict[str, Any]]) List[Dict[str, Any]][source]
Overview:

For a given trajectory (transitions, a list of transition dicts), process it into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with n-step TD) or several consecutive transitions (DRQN).

Arguments:
  • data (List[Dict[str, Any]]): The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.

Returns:
  • samples (List[Dict[str, Any]]): The list of training samples.

Note

We will vectorize the process_transition and get_train_sample methods in a following release version. The user can customize this data processing procedure by overriding these two methods and the collector itself.
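
A hedged sketch of the transition-to-sample path described above, with a hand-made trajectory instead of real env interaction; the field names follow the _process_transition and _forward_learn key lists, while the shapes and the Timestep namedtuple are assumptions.

```python
from collections import namedtuple
import torch

Timestep = namedtuple('Timestep', ['obs', 'reward', 'done', 'info'])  # assumed field set

trajectory = []
obs = torch.randn(4)  # hypothetical 4-dim observation
for t in range(8):
    policy_output = {'action': torch.tensor(0), 'logit': torch.randn(2)}  # fabricated collect output
    timestep = Timestep(obs=torch.randn(4), reward=torch.tensor([1.0]), done=(t == 7), info={})
    trajectory.append(policy._process_transition(obs, policy_output, timestep))
    obs = timestep.obs

# Samples processed according to the policy's configured nstep/unroll_len, ready for the replay buffer.
samples = policy._get_train_sample(trajectory)
print(len(samples), samples[0].keys())
```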

_init_collect() None[source]
Overview:

Collect mode init method. Called by self.__init__; initializes algorithm arguments and collect_model, and enables eps_greedy_sample for exploration.

_init_eval() None[source]
Overview:

Evaluation mode init method. Called by self.__init__; initializes eval_model.

_init_learn() None[source]
Overview:

Learn mode init method. Called by self.__init__; initializes the optimizer, algorithm arguments, and the main and target models.

_load_state_dict_learn(state_dict: Dict[str, Any]) None[source]
Overview:

Load the state_dict variable into policy learn mode.

Arguments:
  • state_dict (Dict[str, Any]): The dict of the policy learn state saved previously.

Tip

If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_process_transition(obs: Any, policy_output: Dict[str, Any], timestep: collections.namedtuple) Dict[str, Any][source]
Overview:

Generate a transition (e.g. <s, a, s', r, d>) for this algorithm's training.

Arguments:
  • obs (Any): Env observation.

  • policy_output (Dict[str, Any]): The output of the policy collect mode (self._forward_collect), including at least action.

  • timestep (namedtuple): The output after the env step (executing the policy output action), including at least obs, reward, done (here obs indicates the obs after the env step).

Returns:
  • transition (dict): Dict type transition data.

_state_dict_learn() Dict[str, Any][source]
Overview:

Return the state_dict of learn mode, usually including model and optimizer.

Returns:
  • state_dict (Dict[str, Any]): the dict of current policy learn state, for saving and restoring.
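
A hedged save/restore sketch using these two methods, assuming `policy` has learn mode enabled; the file path is hypothetical, and in full experiments checkpointing is normally handled by the learner (see ding.torch_utils.checkpoint_helper).

```python
import torch

state = policy._state_dict_learn()                # includes at least model and optimizer state
torch.save(state, './dqn_policy_learn.pth.tar')   # hypothetical path
restored = torch.load('./dqn_policy_learn.pth.tar')
policy._load_state_dict_learn(restored)           # restores the learn-mode state in place
```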

default_model() Tuple[str, List[str]][source]
Overview:

Return this algorithm's default model setting for demonstration.

Returns:
  • model_info (Tuple[str, List[str]]): The model name and model import_names.

Note

The user can define and use a customized network model, but it must obey the same interface definition indicated by the import_names path. For DQN, it is ding.model.template.q_learning.DQN.
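
A hedged inspection sketch, assuming `policy` is an initialized DQNPolicy; the printed values should correspond to the DQN model mentioned in the note above.

```python
model_name, import_names = policy.default_model()
print(model_name)    # e.g. 'dqn'
print(import_names)  # e.g. ['ding.model.template.q_learning']
```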