rl_utils.adder
adder
Adder
- class ding.rl_utils.adder.Adder
- Overview:
Adder is a component that handles the transformations and calculations applied to transitions in the Collector module (data generation and processing), such as GAE, n-step return, and transition sampling.
- Interface:
__init__, get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample
- classmethod get_gae(data: List[Dict[str, Any]], last_value: torch.Tensor, gamma: float, gae_lambda: float, cuda: bool) -> List[Dict[str, Any]]
- Overview:
Get the GAE advantage for stacked transitions (T timesteps, batch size 1). Calls gae for the calculation.
- Arguments:
data (list): Transitions list; each element is a transition dict containing at least [‘value’, ‘reward’].
last_value (torch.Tensor): The last value (i.e. the value at timestep T+1).
gamma (float): The future discount factor.
gae_lambda (float): The GAE lambda parameter.
cuda (bool): Whether to use CUDA in the GAE computation.
- Returns:
data (list): Transitions list like the input, but each element gains an extra advantage key ‘adv’.
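The backward recursion behind a GAE advantage can be sketched as follows. This is a minimal illustration on plain Python floats, not the library's gae function: the name gae_sketch is hypothetical, torch tensors and the cuda flag are left out, and the trajectory is assumed to contain no episode boundary before its end.

```python
from typing import Any, Dict, List


def gae_sketch(data: List[Dict[str, Any]], last_value: float,
               gamma: float, gae_lambda: float) -> List[Dict[str, Any]]:
    """Hypothetical stand-in for a GAE pass over T stacked transitions."""
    next_value = last_value  # value at timestep T+1, bootstraps the final step
    running_adv = 0.0
    # Walk backwards so each step can reuse the advantage of the step after it.
    for step in reversed(data):
        # One-step TD residual: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = step['reward'] + gamma * next_value - step['value']
        running_adv = delta + gamma * gae_lambda * running_adv
        step['adv'] = running_adv  # the extra key added to every transition
        next_value = step['value']
    return data


traj = [{'value': 0.5, 'reward': 1.0}, {'value': 0.4, 'reward': 0.0}]
out = gae_sketch(traj, last_value=0.0, gamma=0.9, gae_lambda=1.0)
```

With gae_lambda set to 1.0, the advantage of the first transition reduces to the discounted return minus its value estimate, which is a convenient sanity check.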
- classmethod get_gae_with_default_last_value(data: collections.deque, done: bool, gamma: float, gae_lambda: float, cuda: bool) -> List[Dict[str, Any]]
- Overview:
Like get_gae above, this computes the GAE advantage for stacked transitions, but it is designed for the case where last_value is not passed. If the transition is not done yet, it assigns the last value in data to last_value, discards the last element in data (i.e. len(data) decreases by 1), and then calls get_gae. Otherwise, it sets last_value to 0.
- Arguments:
data (deque): Transitions list; each element is a transition dict containing at least [‘value’, ‘reward’].
done (bool): Whether the transition reaches the end of an episode (i.e. whether the env is done).
gamma (float): The future discount factor.
gae_lambda (float): The GAE lambda parameter.
cuda (bool): Whether to use CUDA in the GAE computation.
- Returns:
data (List[Dict[str, Any]]): Transitions list like the input, but each element gains an extra advantage key ‘adv’.
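The choice of last_value described above can be sketched in a few lines. The helper name split_last_value is hypothetical, plain floats stand in for tensors, and "the last value in data" is read here as the ‘value’ entry of the final transition, which is an assumption.

```python
from collections import deque
from typing import Any, Deque, Dict, Tuple


def split_last_value(data: Deque[Dict[str, Any]],
                     done: bool) -> Tuple[Deque[Dict[str, Any]], float]:
    """Hypothetical helper: pick the bootstrap value before a GAE pass."""
    if done:
        # Episode finished: there is no future value to bootstrap from.
        return data, 0.0
    # Episode still running: the newest transition only supplies the bootstrap
    # value and is removed, so len(data) shrinks by one.
    last = data.pop()
    return data, last['value']


running = deque([
    {'value': 0.5, 'reward': 1.0},
    {'value': 0.4, 'reward': 0.0},
    {'value': 0.3, 'reward': 0.0},
])
running, running_last = split_last_value(running, done=False)

finished = deque([{'value': 0.5, 'reward': 1.0}])
finished, finished_last = split_last_value(finished, done=True)
```

The returned pair would then be handed to the GAE computation together with gamma and gae_lambda.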
- classmethod get_nstep_return_data(data: collections.deque, nstep: int, cum_reward=False, correct_terminate_gamma=True, gamma=0.99) -> collections.deque
- Overview:
Process raw trajectory data by updating the keys [‘next_obs’, ‘reward’, ‘done’] in each transition dict of data.
- Arguments:
data (deque): Transitions list; each element is a transition dict.
nstep (int): Number of steps. If equal to 1, data is returned directly; otherwise each transition is updated with its n-step values.
- Returns:
data (deque): Transitions list like the input, but with each element updated with its n-step values.
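One plausible reading of the n-step update is sketched below. The function nstep_sketch is hypothetical and is not the library's implementation: it always accumulates a discounted reward sum, shortens the window near the end of the trajectory, and assumes that only the final transition can have done set. How cum_reward and correct_terminate_gamma change the result is not described above, so they are left out.

```python
from collections import deque
from typing import Any, Deque, Dict


def nstep_sketch(data: Deque[Dict[str, Any]], nstep: int,
                 gamma: float = 0.99) -> Deque[Dict[str, Any]]:
    """Hypothetical n-step rewrite of 'reward', 'next_obs' and 'done'."""
    if nstep == 1:
        return data  # nothing to update
    steps = list(data)
    out: Deque[Dict[str, Any]] = deque()
    for i, step in enumerate(steps):
        # Assumption: near the end of the trajectory the window is shorter.
        window = steps[i:i + nstep]
        new_step = dict(step)
        # Discounted sum of the rewards inside the window.
        new_step['reward'] = sum(gamma ** k * s['reward'] for k, s in enumerate(window))
        # The observation reached after the last step of the window.
        new_step['next_obs'] = window[-1]['next_obs']
        # The sample is terminal if the episode ends inside the window.
        new_step['done'] = any(s['done'] for s in window)
        out.append(new_step)
    return out


raw = deque([
    {'obs': 0, 'next_obs': 1, 'reward': 1.0, 'done': False},
    {'obs': 1, 'next_obs': 2, 'reward': 1.0, 'done': False},
    {'obs': 2, 'next_obs': 3, 'reward': 1.0, 'done': True},
])
two_step = nstep_sketch(raw, nstep=2, gamma=0.5)
```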
- classmethod get_train_sample(data: List[Dict[str, Any]], unroll_len: int, last_fn_type: str = 'last', null_transition: Optional[dict] = None) -> List[Dict[str, Any]]
- Overview:
Process raw trajectory data by updating the keys [‘next_obs’, ‘reward’, ‘done’] in each transition dict of data. If unroll_len equals 1, no processing is needed and data is returned directly. Otherwise, data is split according to unroll_len, the residual part is processed according to last_fn_type, and lists_to_dicts is called to form the sampled training data.
- Arguments:
data (List[Dict[str, Any]]): Transitions list; each element is a transition dict.
unroll_len (int): The unroll length used in learner training.
last_fn_type (str): The name of the method used to handle the residual data left at the end of a trajectory after splitting; must be one of [‘last’, ‘drop’, ‘null_padding’].
null_transition (Optional[dict]): A dict-type null transition, used in null_padding.
- Returns:
data (List[Dict[str, Any]]): Transitions list processed after unrolling.
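The split-and-residual logic can be sketched as follows. Both function names are hypothetical, and the meaning given to each last_fn_type is an assumption, since only the option names are listed above: ‘drop’ discards the residual, ‘last’ reuses the final unroll_len transitions (overlapping the previous chunk), and ‘null_padding’ fills the residual up with copies of null_transition.

```python
from typing import Any, Dict, List, Optional


def lists_to_dicts_sketch(chunk: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical stand-in for lists_to_dicts: list of dicts -> dict of lists."""
    return {key: [step[key] for step in chunk] for key in chunk[0]}


def train_sample_sketch(data: List[Dict[str, Any]], unroll_len: int,
                        last_fn_type: str = 'last',
                        null_transition: Optional[dict] = None) -> List[Dict[str, Any]]:
    if unroll_len == 1:
        return data  # no processing needed
    n_full = len(data) // unroll_len
    chunks = [data[i * unroll_len:(i + 1) * unroll_len] for i in range(n_full)]
    residual = data[n_full * unroll_len:]
    if residual:
        if last_fn_type == 'drop':
            pass  # assumption: the residual is simply discarded
        elif last_fn_type == 'last':
            # Assumption: take the final unroll_len transitions; skipped when
            # the whole trajectory is shorter than unroll_len.
            if len(data) >= unroll_len:
                chunks.append(data[-unroll_len:])
        elif last_fn_type == 'null_padding':
            if null_transition is None:
                raise ValueError("null_padding needs a null_transition")
            pad = [dict(null_transition) for _ in range(unroll_len - len(residual))]
            chunks.append(residual + pad)
        else:
            raise ValueError(f"unknown last_fn_type: {last_fn_type}")
    return [lists_to_dicts_sketch(chunk) for chunk in chunks]


traj = [{'obs': i, 'reward': float(i)} for i in range(5)]
dropped = train_sample_sketch(traj, 2, 'drop')
overlapped = train_sample_sketch(traj, 2, 'last')
padded = train_sample_sketch(traj, 2, 'null_padding', {'obs': -1, 'reward': 0.0})
```

Each returned element groups unroll_len consecutive transitions under shared keys, which is the shape a sequence model would typically consume.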