rl_utils.vtrace

vtrace

vtrace_error

Overview:

Implementation of V-trace (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, arXiv:1802.01561).
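
For reference, the n-step V-trace value target defined in the paper is

\[
v_s = V(x_s) + \sum_{t=s}^{s+n-1} \gamma^{\,t-s} \Big( \prod_{i=s}^{t-1} c_i \Big) \delta_t V, \qquad \delta_t V = \rho_t \big( r_t + \gamma V(x_{t+1}) - V(x_t) \big),
\]

with truncated importance weights \(\rho_t = \min\big(\bar{\rho},\ \pi(a_t \mid x_t) / \mu(a_t \mid x_t)\big)\) and \(c_i = \lambda \min\big(\bar{c},\ \pi(a_i \mid x_i) / \mu(a_i \mid x_i)\big)\), where \(\pi\) is the current (target) policy, \(\mu\) is the behaviour policy, \(\bar{\rho}\) corresponds to rho_clip_ratio, \(\bar{c}\) to c_clip_ratio and \(\lambda\) to lambda_ in the arguments below.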

Arguments:
  • data (namedtuple): input data with fields shown in vtrace_data
    • target_output (torch.Tensor): the output of the current (target) policy network for the taken action, usually the network's output logits

    • behaviour_output (torch.Tensor): the output of the behaviour policy network for the taken action, usually the network's output logits; the behaviour policy is the one used to produce the trajectory (i.e. the collector's policy)

    • action (torch.Tensor): the action chosen in the trajectory (the index in a discrete action space), i.e. the behaviour action

  • gamma (float): the discount factor for future rewards, defaults to 0.95

  • lambda_ (float): the mixing factor between the 1-step (lambda_=0) and n-step (lambda_=1) return, defaults to 1.0

  • rho_clip_ratio (float): the clipping threshold \(\bar{\rho}\) for the importance weights applied to the temporal-difference term when computing the value targets \(v_s\)

  • c_clip_ratio (float): the clipping threshold \(\bar{c}\) for the importance weights (traces \(c_i\)) used when computing the value targets \(v_s\)

  • rho_pg_clip_ratio (float): the clipping threshold \(\bar{\rho}_{pg}\) for the importance weight that scales the policy-gradient advantage (see the formula after this list)
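
The policy-gradient part of the loss uses the V-trace targets as bootstrap values; following the paper, the advantage at step \(s\) is scaled by a clipped importance weight:

\[
\rho_s^{\,pg} \, \nabla_\omega \log \pi_\omega(a_s \mid x_s) \big( r_s + \gamma v_{s+1} - V(x_s) \big), \qquad \rho_s^{\,pg} = \min\big(\bar{\rho}_{pg},\ \pi(a_s \mid x_s) / \mu(a_s \mid x_s)\big),
\]

where \(\bar{\rho}_{pg}\) corresponds to rho_pg_clip_ratio (the paper reuses \(\bar{\rho}\) here; this function exposes it as a separate threshold).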

Returns:
  • trace_loss (namedtuple): the v-trace loss items; each field is a differentiable 0-dim tensor

Shapes:
  • target_output (torch.FloatTensor): \((T, B, N)\), where T is the trajectory length (number of timesteps), B is the batch size and N is the action dimension

  • behaviour_output (torch.FloatTensor): \((T, B, N)\)

  • action (torch.LongTensor): \((T, B)\)

  • value (torch.FloatTensor): \((T+1, B)\)

  • reward (torch.FloatTensor): \((T, B)\)

  • weight (torch.FloatTensor): \((T, B)\)
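
A minimal usage sketch. The import path and the namedtuple field names are assumptions inferred from the module name and the Arguments/Shapes sections above, and may differ in your installation:

    import torch
    # Assumed import path, following the module name at the top of this page;
    # adjust (e.g. add a package prefix) to match your installation.
    from rl_utils.vtrace import vtrace_data, vtrace_error

    T, B, N = 16, 4, 6  # trajectory length, batch size, action dimension
    data = vtrace_data(
        target_output=torch.randn(T, B, N, requires_grad=True),  # logits of the current policy
        behaviour_output=torch.randn(T, B, N),                    # logits of the behaviour (collector) policy
        action=torch.randint(0, N, (T, B)),                       # actions taken in the trajectory
        value=torch.randn(T + 1, B, requires_grad=True),          # value estimates, incl. the bootstrap step
        reward=torch.randn(T, B),                                 # rewards
        weight=torch.ones(T, B),                                  # per-sample weight / padding mask
    )
    loss = vtrace_error(data)   # namedtuple of differentiable 0-dim loss tensors
    total_loss = sum(loss)      # namedtuples are iterable, so the items can be summed
    total_loss.backward()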