R2D2¶
R2D2Policy¶
- class ding.policy.r2d2.R2D2Policy(cfg: dict, model: Optional[Union[type, torch.nn.modules.module.Module]] = None, enable_field: Optional[List[str]] = None)[source]¶
- Overview:
Policy class of R2D2, from the paper Recurrent Experience Replay in Distributed Reinforcement Learning. R2D2 proposes several tricks to improve upon DRQN, namely recurrent experience replay techniques such as burn-in (re-warming the RNN hidden state for a few steps before computing the loss).
- Config:
| ID | Symbol | Type | Default Value | Description | Other (Shape) |
| --- | --- | --- | --- | --- | --- |
| 1 | type | str | r2d2 | RL policy register name, refer to registry POLICY_REGISTRY | This arg is optional, a placeholder |
| 2 | cuda | bool | False | Whether to use cuda for network | This arg can be different from modes |
| 3 | on_policy | bool | False | Whether the RL algorithm is on-policy or off-policy | |
| 4 | priority | bool | False | Whether to use priority (PER) | Priority sample, update priority |
| 5 | priority_IS_weight | bool | False | Whether to use Importance Sampling Weight to correct the biased update. If True, priority must be True. | |
| 6 | discount_factor | float | 0.997, [0.95, 0.999] | Reward's future discount factor, aka. gamma | May be 1 for sparse reward envs |
| 7 | nstep | int | 3, [3, 5] | N-step reward discount sum for target q_value estimation | |
| 8 | burnin_step | int | 2 | The timestep of the burn-in operation, which is designed to reduce the RNN hidden state difference caused by off-policy data | |
| 9 | learn.update_per_collect | int | 1 | How many updates (iterations) to train after one collection by the collector. Only valid in serial training | This arg varies across envs. A bigger value means more off-policy |
| 10 | learn.batch_size | int | 64 | The number of samples of an iteration | |
| 11 | learn.learning_rate | float | 0.001 | Gradient step length of an iteration | |
| 12 | learn.value_rescale | bool | True | Whether to use the value_rescale function for the predicted value | |
| 13 | learn.target_update_freq | int | 100 | Frequency of target network update | Hard (assign) update |
| 14 | learn.ignore_done | bool | False | Whether to ignore done for target value calculation | Enable it for some fake-termination envs |
| 15 | collect.n_sample | int | [8, 128] | The number of training samples of one call of the collector | It varies from different envs |
| 16 | collect.unroll_len | int | 1 | Unroll length of an iteration | In RNN, unroll_len > 1 |
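The table above maps directly onto the nested config dict that R2D2Policy consumes. Below is a minimal, illustrative sketch of such a config using the defaults listed above; it is not a complete experiment config (model and env fields are omitted), and values such as n_sample and unroll_len normally need per-environment tuning.

```python
# Illustrative only: a partial policy config mirroring the table above.
# Model and env sections are omitted; values should be tuned per environment.
r2d2_policy_config = dict(
    type='r2d2',
    cuda=False,
    on_policy=False,
    priority=False,
    priority_IS_weight=False,
    discount_factor=0.997,
    nstep=3,
    burnin_step=2,
    learn=dict(
        update_per_collect=1,
        batch_size=64,
        learning_rate=0.001,
        value_rescale=True,
        target_update_freq=100,
        ignore_done=False,
    ),
    collect=dict(
        n_sample=32,      # within the suggested [8, 128] range
        unroll_len=1,     # for RNN training, usually set > 1 (see table)
    ),
)
```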
- _forward_collect(data: dict, eps: float) → dict [source]¶
- Overview:
Forward function for collect mode with eps_greedy
- Arguments:
  - data (Dict[str, Any]): Dict type data, stacked env data for predicting policy_output (action); values are torch.Tensor, np.ndarray, or dict/list combinations, and keys are env_id indicated by integer.
  - eps (float): epsilon value for exploration, which is decayed by collected env step.
- Returns:
  - output (Dict[int, Any]): Dict type data, including at least the inferred action according to input obs.
- ReturnsKeys:
  - necessary: action
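A minimal sketch of how collect-mode inference might be invoked, assuming an already-constructed policy and a toy observation batch; the observation shape, env ids, and the eps value are purely illustrative.

```python
import torch

# Hypothetical usage: `policy` is an initialized R2D2Policy and the model
# expects flat observations of size 4. Keys of the obs dict are env_id integers.
obs = {0: torch.randn(4), 1: torch.randn(4)}
output = policy.collect_mode.forward(obs, eps=0.05)  # eps-greedy exploration
action_env_0 = output[0]['action']
```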
- _forward_eval(data: dict) → dict [source]¶
- Overview:
Forward function of eval mode, similar to self._forward_collect.
- Arguments:
  - data (Dict[str, Any]): Dict type data, stacked env data for predicting policy_output (action); values are torch.Tensor, np.ndarray, or dict/list combinations, and keys are env_id indicated by integer.
- Returns:
  - output (Dict[int, Any]): The dict of predicted actions for interaction with the env.
- ReturnsKeys:
  - necessary: action
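For comparison, an eval-mode call could look like the sketch below (same assumptions as above); no eps is passed because eval mode selects actions greedily.

```python
import torch

# Hypothetical usage: eval mode acts greedily (argmax), so no epsilon is needed.
obs = {0: torch.randn(4)}
output = policy.eval_mode.forward(obs)
action = output[0]['action']
```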
- _forward_learn(data: dict) → Dict[str, Any] [source]¶
- Overview:
Forward and backward function of learn mode. Acquire the data, calculate the loss and optimize learner model.
- Arguments:
  - data (dict): Dict type data, including at least ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight'].
- Returns:
  - info_dict (Dict[str, Any]): Including cur_lr and total_loss.
    - cur_lr (float): Current learning rate.
    - total_loss (float): The calculated loss.
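The split into 'burnin_obs', 'main_obs', and 'target_obs' can be pictured as slicing one unrolled observation sequence: the first burnin_step frames only re-warm the RNN hidden state, while the remaining frames provide the current and n-step-shifted states for the Q targets. A hedged sketch of that slicing (shapes and exact indexing inside DI-engine may differ):

```python
import torch

# Illustrative slicing of an unrolled sequence of length burnin_step + unroll_len.
burnin_step, nstep, unroll_len = 2, 3, 10
batch_size, obs_dim = 4, 8
obs_seq = torch.randn(burnin_step + unroll_len, batch_size, obs_dim)

burnin_obs = obs_seq[:burnin_step]           # warms up the hidden state, no loss computed
main_obs = obs_seq[burnin_step:-nstep]       # states for the current Q estimates
target_obs = obs_seq[burnin_step + nstep:]   # n-step shifted states for the target Q
```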
- _init_collect() → None [source]¶
- Overview:
Collect mode init method, called by self.__init__. Initialize the trajectory and unroll length, and the collect model.
- _init_eval() → None [source]¶
- Overview:
Evaluate mode init method, called by self.__init__. Init the eval model with an argmax strategy.
- _init_learn() → None [source]¶
- Overview:
Init the learner model of R2D2Policy
- Arguments:
Note
The _init_learn method takes its arguments from self._cfg.learn in the config file.
  - learning_rate (float): The learning rate of the optimizer.
  - gamma (float): The discount factor.
  - nstep (int): The number of steps in the n-step return.
  - value_rescale (bool): Whether to use the value-rescaled loss in the algorithm.
  - burnin_step (int): The number of burn-in steps.
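When value_rescale is enabled, Q targets are squashed with the invertible value transform used in the R2D2 paper. Below is a hedged sketch of that transform and its inverse, assuming the common eps = 1e-3; DI-engine's own rl_utils implementation may differ in detail.

```python
import torch

# Sketch of the invertible value-rescaling transform h(x) and its inverse,
# as described in the R2D2 paper; eps = 1e-3 is an assumed default.
def value_transform(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.) - 1.) + eps * x

def value_inv_transform(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    return torch.sign(x) * (
        ((torch.sqrt(1. + 4. * eps * (torch.abs(x) + 1. + eps)) - 1.) / (2. * eps)) ** 2 - 1.
    )
```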
- _process_transition(obs: Any, model_output: dict, timestep: collections.namedtuple) → dict [source]¶
- Overview:
Generate dict type transition data from inputs.
- Arguments:
  - obs (Any): Env observation.
  - model_output (dict): Output of the collect model, including at least ['action', 'prev_state'].
  - timestep (namedtuple): Output after env step, including at least ['reward', 'done'] (here 'obs' indicates obs after env step).
- Returns:
  - transition (dict): Dict type transition data.
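As an illustration, a processed R2D2 transition carries the recurrent state alongside the usual fields, so that replayed sequences can later be re-initialized for burn-in. A hedged sketch of what the returned dict might contain (any keys beyond those documented above are assumptions):

```python
# Illustrative structure of the returned transition; exact keys may differ.
transition = {
    'obs': obs,                                # observation before the env step
    'action': model_output['action'],          # action chosen by the collect model
    'prev_state': model_output['prev_state'],  # RNN hidden state before acting
    'reward': timestep.reward,
    'done': timestep.done,
}
```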