rl_utils.isw

compute_importance_weights

Overview:

Compute the importance sampling weights of the current (target) policy relative to the behaviour policy, given both policies' outputs and the actions taken.
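
Concretely, for each timestep \(t\) the weight is the ratio of the taken action's probability under the two policies (a standard importance-sampling ratio; the notation below is ours, reconstructed from the argument descriptions):

\[
\rho_t = \frac{\pi_{\text{target}}(a_t \mid s_t)}{\pi_{\text{behaviour}}(a_t \mid s_t)}
\]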

Arguments:
  • target_output (torch.Tensor): the output of the current (target) policy network for the states in the trajectory; usually this is the network's raw output logits

  • behaviour_output (torch.Tensor): the output of the behaviour policy network for the same states, usually also raw logits; the behaviour policy is the one used to produce the trajectory (i.e., the collector's policy)

  • action (torch.Tensor): the action chosen at each step of the trajectory (an index into the discrete action space), i.e., the behaviour action

  • requires_grad (bool): whether gradient computation is required for the returned weights

Returns:
  • rhos (torch.Tensor): the per-timestep importance sampling weights

Shapes:
  • target_output (torch.FloatTensor): \((T, B, N)\), where T is the number of timesteps, B is the batch size and N is the action dimension

  • behaviour_output (torch.FloatTensor): \((T, B, N)\)

  • action (torch.LongTensor): \((T, B)\)

  • rhos (torch.FloatTensor): \((T, B)\)
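
Below is a minimal sketch of how such a function can be implemented from the documented arguments and shapes. It assumes the outputs are discrete-action logits (as the argument descriptions suggest) and computes the ratio in log space, a common numerical-stability choice; this is a reconstruction, not necessarily the library's actual implementation.

```python
import torch
import torch.nn.functional as F


def compute_importance_weights(target_output: torch.Tensor,
                               behaviour_output: torch.Tensor,
                               action: torch.Tensor,
                               requires_grad: bool = False) -> torch.Tensor:
    # Sketch implementation: enable or disable autograd per the flag.
    grad_context = torch.enable_grad() if requires_grad else torch.no_grad()
    with grad_context:
        # log pi(a_t|s_t) under the current (target) policy: (T, B)
        target_logp = F.log_softmax(target_output, dim=-1)
        target_logp = target_logp.gather(-1, action.unsqueeze(-1)).squeeze(-1)
        # log mu(a_t|s_t) under the behaviour (collector) policy: (T, B)
        behaviour_logp = F.log_softmax(behaviour_output, dim=-1)
        behaviour_logp = behaviour_logp.gather(-1, action.unsqueeze(-1)).squeeze(-1)
        # rho_t = pi(a_t|s_t) / mu(a_t|s_t), via exp of the log difference
        rhos = torch.exp(target_logp - behaviour_logp)
    return rhos


# Usage with the documented shapes: T timesteps, batch size B, action dim N.
T, B, N = 8, 4, 6
rhos = compute_importance_weights(torch.randn(T, B, N), torch.randn(T, B, N),
                                  torch.randint(0, N, (T, B)))
assert rhos.shape == (T, B)
```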