DQN¶
DQNPolicy¶
- class ding.policy.dqn.DQNPolicy(cfg: dict, model: Optional[Union[type, torch.nn.modules.module.Module]] = None, enable_field: Optional[List[str]] = None)[source]¶
- Overview:
Policy class of the DQN algorithm, extended by Double DQN/Dueling DQN/PER/multi-step TD.
- Config:
| ID | Symbol | Type | Default Value | Description | Other (Shape) |
| --- | --- | --- | --- | --- | --- |
| 1 | type | str | dqn | RL policy register name, refer to registry POLICY_REGISTRY | This arg is optional, a placeholder |
| 2 | cuda | bool | False | Whether to use cuda for network | This arg can be different from modes |
| 3 | on_policy | bool | False | Whether the RL algorithm is on-policy or off-policy | |
| 4 | priority | bool | False | Whether use priority (PER) | Priority sample, update priority |
| 5 | priority_IS_weight | bool | False | Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. | |
| 6 | discount_factor | float | 0.97, [0.95, 0.999] | Reward's future discount factor, aka. gamma | May be 1 when sparse reward env |
| 7 | nstep | int | 1, [3, 5] | N-step reward discount sum for target q_value estimation | |
| 8 | learn.update_per_collect | int | 3 | How many updates (iterations) to train after collector's one collection. Only valid in serial training | This arg can vary from envs. Bigger val means more off-policy |
| 9 | learn.multi_gpu | bool | False | Whether to use multi gpu during training | |
| 10 | learn.batch_size | int | 64 | The number of samples of an iteration | |
| 11 | learn.learning_rate | float | 0.001 | Gradient step length of an iteration | |
| 12 | learn.target_update_freq | int | 100 | Frequency of target network update | Hard (assign) update |
| 13 | learn.ignore_done | bool | False | Whether ignore done for target value calculation | Enable it for some fake termination env |
| 14 | collect.n_sample | int | [8, 128] | The number of training samples of a call of collector | It varies from different envs |
| 15 | collect.unroll_len | int | 1 | Unroll length of an iteration | In RNN, unroll_len > 1 |
| 16 | other.eps.type | str | exp | Exploration rate decay type | Support ['exp', 'linear'] |
| 17 | other.eps.start | float | 0.95 | Start value of exploration rate | [0, 1] |
| 18 | other.eps.end | float | 0.1 | End value of exploration rate | [0, 1] |
| 19 | other.eps.decay | int | 10000 | Decay length of exploration | Greater than 0. decay=10000 means the exploration rate decays from the start value to the end value over the decay length |
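The defaults above map onto a nested policy config. Below is a minimal sketch of such a config, assuming DI-engine's usual EasyDict-based nesting (learn/collect/other sub-dicts); the values simply mirror the table's defaults and are not tuned for any particular environment.

```python
from easydict import EasyDict

# Illustrative DQN policy config mirroring the table's default values.
# The nesting (learn/collect/other) follows the field names above; this is a
# sketch, not a complete, tuned experiment config.
dqn_policy_config = EasyDict(dict(
    type='dqn',
    cuda=False,
    on_policy=False,
    priority=False,
    priority_IS_weight=False,
    discount_factor=0.97,
    nstep=1,
    learn=dict(
        update_per_collect=3,
        multi_gpu=False,
        batch_size=64,
        learning_rate=0.001,
        target_update_freq=100,
        ignore_done=False,
    ),
    collect=dict(n_sample=8, unroll_len=1),
    other=dict(eps=dict(type='exp', start=0.95, end=0.1, decay=10000)),
))
```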
- _forward_collect(data: Dict[int, Any], eps: float) → Dict[int, Any] [source]¶
- Overview:
Forward computation graph of collect mode (collect training data), with eps_greedy for exploration.
- Arguments:
data (Dict[int, Any]): Dict type data, stacked env data for predicting policy_output (action); values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
eps (float): epsilon value for exploration, which is decayed by collected env step.
- Returns:
output (Dict[int, Any]): The dict of predicted policy_output (action) for the interaction with env and the constructing of transition.
- ArgumentsKeys:
necessary: obs
- ReturnsKeys:
necessary: logit, action
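As a usage sketch, a collector would call this method (usually through the collect-mode interface) with an env_id-keyed dict of observations plus the current epsilon. The `policy` object and the 4-dim observation shape below are assumptions for illustration, not part of the API reference.

```python
import torch

def collect_one_step(policy, eps: float = 0.5):
    # Hypothetical call sketch: two vectorized envs with 4-dim observations
    # (e.g. CartPole-sized); dict keys are env_ids.
    data = {0: torch.randn(4), 1: torch.randn(4)}
    output = policy.collect_mode.forward(data, eps=eps)  # wraps _forward_collect
    # Each entry is expected to contain at least 'logit' and 'action'.
    return {env_id: out['action'] for env_id, out in output.items()}
```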
- _forward_eval(data: Dict[int, Any]) → Dict[int, Any] [source]¶
- Overview:
Forward computation graph of eval mode (evaluate policy performance); in most cases, it is similar to self._forward_collect.
- Arguments:
data (Dict[int, Any]): Dict type data, stacked env data for predicting policy_output (action); values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
- Returns:
output (Dict[int, Any]): The dict of predicted action for the interaction with env.
- ArgumentsKeys:
necessary: obs
- ReturnsKeys:
necessary: action
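The eval-mode counterpart takes the same env_id-keyed dict but no epsilon, since evaluation acts greedily. A hypothetical sketch, again assuming a pre-built `policy` and 4-dim observations:

```python
import torch

def eval_one_step(policy):
    # Same data layout as collect mode, but no eps argument: eval mode
    # picks the greedy (argmax) action from the Q-network's logits.
    data = {0: torch.randn(4), 1: torch.randn(4)}
    output = policy.eval_mode.forward(data)  # wraps _forward_eval
    return {env_id: out['action'] for env_id, out in output.items()}
```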
- _forward_learn(data: Dict[str, Any]) → Dict[str, Any] [source]¶
- Overview:
Forward computation graph of learn mode (updating policy).
- Arguments:
data (Dict[str, Any]): Dict type data, a batch of data for training; values are torch.Tensor or np.ndarray or dict/list combinations.
- Returns:
info_dict (Dict[str, Any]): Dict type data, an info dict indicating the training result, which will be recorded in text log and tensorboard; values are python scalars or a list of scalars.
- ArgumentsKeys:
necessary: obs, action, reward, next_obs, done
optional: value_gamma, IS
- ReturnsKeys:
necessary: cur_lr, total_loss, priority
optional: action_distribution
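For reference, the sketch below shows the rough shape of one training batch for a discrete-action env with 4-dim observations and nstep=1; the keys follow the ArgumentsKeys above, while the shapes and the commented call are illustrative assumptions.

```python
import torch

batch_size = 64
# Illustrative training batch for _forward_learn (nstep=1, 2 discrete actions).
batch = {
    'obs': torch.randn(batch_size, 4),
    'next_obs': torch.randn(batch_size, 4),
    'action': torch.randint(0, 2, (batch_size,)),
    'reward': torch.randn(batch_size),
    'done': torch.zeros(batch_size),
}
# info = policy.learn_mode.forward(batch)  # would return e.g. cur_lr, total_loss, priority
```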
- _get_train_sample(data: List[Dict[str, Any]]) → List[Dict[str, Any]] [source]¶
- Overview:
For a given trajectory (transitions, a list of transition) data, process it into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) or some continuous transitions (DRQN).
- Arguments:
data (List[Dict[str, Any]]): The trajectory data (a list of transition), each element is in the same format as the return value of the self._process_transition method.
- Returns:
samples (List[Dict[str, Any]]): The list of training samples.
Note
We will vectorize the process_transition and get_train_sample methods in a following release version, and the user can customize this data processing procedure by overriding these two methods and the collector itself, as sketched below.
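A hypothetical override sketch: subclass the policy and replace _get_train_sample with custom trajectory processing, keeping the same list-in/list-out contract. The thinning logic here is purely illustrative.

```python
from typing import Any, Dict, List

from ding.policy import DQNPolicy


class ThinnedDQNPolicy(DQNPolicy):
    """Illustrative subclass: keep every other transition as a train sample."""

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Any custom trajectory-to-sample logic goes here; the only contract is
        # a list of transitions in, a list of training samples out.
        return data[::2]
```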
- _init_collect() → None [source]¶
- Overview:
Collect mode init method. Called by self.__init__; initialize algorithm arguments and collect_model, and enable eps_greedy_sample for exploration.
- _init_eval() → None [source]¶
- Overview:
Evaluate mode init method. Called by self.__init__; initialize eval_model.
- _init_learn() → None [source]¶
- Overview:
Learn mode init method. Called by self.__init__; initialize the optimizer, algorithm arguments, and the main and target model.
- _load_state_dict_learn(state_dict: Dict[str, Any]) → None [source]¶
- Overview:
Load the state_dict variable into policy learn mode.
- Arguments:
state_dict (Dict[str, Any]): The dict of policy learn state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument in load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations; see the partial-load sketch below.
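For instance, a hedged sketch of a partial load: the checkpoint file is assumed to contain a 'model' entry, and only the weights whose names match are restored.

```python
import torch

def load_partial_model(model: torch.nn.Module, ckpt_path: str) -> None:
    # Hypothetical partial restore: strict=False skips missing/unexpected keys,
    # so only the parameters present in both the checkpoint and the model load.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict.get('model', state_dict), strict=False)
```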
- _process_transition(obs: Any, policy_output: Dict[str, Any], timestep: collections.namedtuple) → Dict[str, Any] [source]¶
- Overview:
Generate a transition (e.g. <s, a, s', r, d>) for this algorithm's training.
- Arguments:
obs (Any): Env observation.
policy_output (Dict[str, Any]): The output of policy collect mode (self._forward_collect), including at least action.
timestep (namedtuple): The output after the env step (executing the policy output action), including at least obs, reward, done (here obs indicates the obs after the env step).
- Returns:
transition (dict): Dict type transition data.
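For DQN, the returned transition is a plain dict bundling the fields later consumed by _forward_learn; the sketch below shows the expected layout (field names follow the ArgumentsKeys above, and the helper name is hypothetical).

```python
from typing import Any, Dict

def build_dqn_transition(obs: Any, policy_output: Dict[str, Any], timestep: Any) -> Dict[str, Any]:
    # Sketch of the transition layout DQN training expects.
    return {
        'obs': obs,                         # observation before the env step
        'next_obs': timestep.obs,           # observation after the env step
        'action': policy_output['action'],  # action chosen in collect mode
        'reward': timestep.reward,
        'done': timestep.done,
    }
```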
- _state_dict_learn() → Dict[str, Any] [source]¶
- Overview:
Return the state_dict of learn mode, usually including model and optimizer.
- Returns:
state_dict (Dict[str, Any]): The dict of current policy learn state, for saving and restoring.
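A round-trip sketch through the mode-level interfaces, assuming the learn-mode state_dict holds model and optimizer entries as described above; the checkpoint path is hypothetical.

```python
import torch

def checkpoint_round_trip(policy, path: str = './dqn_ckpt.pth.tar') -> None:
    # Save the learn-mode state (model, optimizer, ...) and restore it,
    # e.g. to resume training. The mode-level wrappers delegate to
    # _state_dict_learn / _load_state_dict_learn.
    torch.save(policy.learn_mode.state_dict(), path)
    policy.learn_mode.load_state_dict(torch.load(path, map_location='cpu'))
```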
- default_model() → Tuple[str, List[str]] [source]¶
- Overview:
Return this algorithm's default model setting for demonstration.
- Returns:
model_info (Tuple[str, List[str]]): model name and model import_names
Note
The user can define and use a customized network model but must obey the same interface definition indicated by the import_names path. For DQN, it is ding.model.template.q_learning.DQN.
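As an illustration, default_model() for this policy returns a pair like ('dqn', ['ding.model.template.q_learning']), and a custom network only needs to expose the same forward interface (an output dict containing 'logit') as that template. The shapes below are illustrative.

```python
from ding.model import DQN

# Instantiate the default DQN model template with CartPole-sized shapes
# (4-dim observation, 2 discrete actions); a customized network obeying the
# same interface would be passed as DQNPolicy(cfg, model=custom_model) instead.
model = DQN(obs_shape=4, action_shape=2)
```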