rl_utils.isw¶
compute_importance_weights¶
- Overview:
    Compute the importance sampling weights given the policy outputs and the taken actions.
- Arguments:
    - target_output (torch.Tensor): the output of the current (target) policy network for the taken action, usually the network's raw logits
    - behaviour_output (torch.Tensor): the output of the behaviour policy network for the taken action, usually the network's raw logits; the behaviour policy is the one that produced the trajectory (in the collector)
    - action (torch.Tensor): the actions chosen in the trajectory (indices for a discrete action space), i.e. behaviour_action
    - requires_grad (bool): whether gradient computation is required
- Returns:
    - rhos (torch.Tensor): the importance sampling weights
- Shapes:
    - target_output (torch.FloatTensor): \((T, B, N)\), where T is the timestep, B the batch size and N the action dimension
    - behaviour_output (torch.FloatTensor): \((T, B, N)\)
    - action (torch.LongTensor): \((T, B)\)
    - rhos (torch.FloatTensor): \((T, B)\)