rl_utils.vtrace
vtrace
vtrace_error
- Overview:
  Implementation of v-trace from IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures (arXiv:1802.01561); the v-trace target is recalled below.
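For reference, the v-trace target defined in the IMPALA paper (with \(\bar{\rho}\) and \(\bar{c}\) corresponding to rho_clip_ratio and c_clip_ratio in the arguments below) is

\[
v_s = V(x_s) + \sum_{t=s}^{s+n-1} \gamma^{t-s} \left( \prod_{i=s}^{t-1} c_i \right) \delta_t V, \qquad
\delta_t V = \rho_t \big( r_t + \gamma V(x_{t+1}) - V(x_t) \big),
\]

where \(\rho_t = \min\!\big(\bar{\rho}, \tfrac{\pi(a_t \mid x_t)}{\mu(a_t \mid x_t)}\big)\) and \(c_i = \min\!\big(\bar{c}, \tfrac{\pi(a_i \mid x_i)}{\mu(a_i \mid x_i)}\big)\) are truncated importance weights of the current policy \(\pi\) against the behaviour policy \(\mu\). The paper's additional discount \(\lambda\) scales \(c_i\), interpolating between 1-step (\(\lambda = 0\)) and n-step returns, which is the role of the lambda_ argument. The policy gradient uses the advantage \(\rho_s^{pg} \big( r_s + \gamma v_{s+1} - V(x_s) \big)\), with \(\rho_s^{pg}\) clipped at rho_pg_clip_ratio.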
- Arguments:
  - data (namedtuple): input data with fields shown in vtrace_data
    - target_output (torch.Tensor): the output of the current policy network for the taken action, usually the network output logit
    - behaviour_output (torch.Tensor): the output of the behaviour policy network for the taken action, usually the network output logit; the behaviour policy is the one used to produce the trajectory (collector)
    - action (torch.Tensor): the action chosen in the trajectory (index for the discrete action space), i.e. behaviour_action
  - gamma (float): the future discount factor, defaults to 0.95
  - lambda_ (float): mix factor between 1-step (lambda_=0) and n-step, defaults to 1.0
  - rho_clip_ratio (float): the clipping threshold for importance weights (rho) when calculating the baseline targets (vs)
  - c_clip_ratio (float): the clipping threshold for importance weights (c) when calculating the baseline targets (vs)
  - rho_pg_clip_ratio (float): the clipping threshold for importance weights (rho) when calculating the policy gradient advantage
- Returns:
  - trace_loss (namedtuple): the v-trace loss items; each field is a differentiable 0-dim tensor
- Shapes:
  - target_output (torch.FloatTensor): \((T, B, N)\), where T is timestep, B is batch size and N is action dim
  - behaviour_output (torch.FloatTensor): \((T, B, N)\)
  - action (torch.LongTensor): \((T, B)\)
  - value (torch.FloatTensor): \((T+1, B)\)
  - reward (torch.LongTensor): \((T, B)\)
  - weight (torch.LongTensor): \((T, B)\)
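A minimal usage sketch following the shapes listed above. The import path, the field order of vtrace_data, the keyword name lambda_, and the field names of the returned loss namedtuple are assumptions inferred from this page rather than guaranteed API details; adjust them to your installation.

```python
import torch
from rl_utils.vtrace import vtrace_data, vtrace_error  # assumed import path

T, B, N = 8, 4, 6  # timesteps, batch size, action dim

# Dummy trajectory tensors with the documented shapes.
target_output = torch.randn(T, B, N, requires_grad=True)   # current policy logits
behaviour_output = torch.randn(T, B, N)                     # behaviour (collector) policy logits
action = torch.randint(0, N, (T, B))                         # actions taken in the trajectory
value = torch.randn(T + 1, B, requires_grad=True)            # value estimates, plus one bootstrap step
reward = torch.randn(T, B)                                   # per-step rewards
weight = torch.ones(T, B)                                    # per-sample weights (mask)

# Field order of vtrace_data is an assumption.
data = vtrace_data(target_output, behaviour_output, action, value, reward, weight)

loss = vtrace_error(
    data,
    gamma=0.95,
    lambda_=1.0,
    rho_clip_ratio=1.0,
    c_clip_ratio=1.0,
    rho_pg_clip_ratio=1.0,
)

# The returned namedtuple holds differentiable 0-dim tensors; the field names
# below are assumptions for illustration.
total_loss = loss.policy_loss + loss.value_loss + loss.entropy_loss
total_loss.backward()
```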